未验证 提交 4a93b001 编写于 作者: L lvmengsi 提交者: GitHub

Update gan (#2871)

* refine gan
上级 7b1a5565
此差异已折叠。
...@@ -54,8 +54,8 @@ add_arg('image_size', int, 128, "image size") ...@@ -54,8 +54,8 @@ add_arg('image_size', int, 128, "image size")
add_arg('selected_attrs', str, add_arg('selected_attrs', str,
"Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young", "Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
"the attributes we selected to change") "the attributes we selected to change")
add_arg('batch_size', int, 16, "batch size when test") add_arg('n_samples', int, 16, "batch size when test")
add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt", "the test list file") add_arg('test_list', str, "./data/celeba/list_attr_celeba.txt", "the test list file")
add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory to be infered") add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory to be infered")
add_arg('n_layers', int, 5, "default layers in generotor") add_arg('n_layers', int, 5, "default layers in generotor")
add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor") add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor")
...@@ -149,11 +149,9 @@ def infer(args): ...@@ -149,11 +149,9 @@ def infer(args):
test_reader = celeba_reader_creator( test_reader = celeba_reader_creator(
image_dir=args.dataset_dir, image_dir=args.dataset_dir,
list_filename=args.test_list, list_filename=args.test_list,
batch_size=args.batch_size, args=args,
drop_last=False, mode="VAL")
args=args) reader_test = test_reader.make_reader(return_name=True)
reader_test = test_reader.get_test_reader(
args, shuffle=False, return_name=True)
for data in zip(reader_test()): for data in zip(reader_test()):
real_img, label_org, name = data[0] real_img, label_org, name = data[0]
print("read {}".format(name)) print("read {}".format(name))
...@@ -199,11 +197,9 @@ def infer(args): ...@@ -199,11 +197,9 @@ def infer(args):
test_reader = celeba_reader_creator( test_reader = celeba_reader_creator(
image_dir=args.dataset_dir, image_dir=args.dataset_dir,
list_filename=args.test_list, list_filename=args.test_list,
batch_size=args.batch_size, args=args,
drop_last=False, mode="VAL")
args=args) reader_test = test_reader.make_reader(return_name=True)
reader_test = test_reader.get_test_reader(
args, shuffle=False, return_name=True)
for data in zip(reader_test()): for data in zip(reader_test()):
real_img, label_org, name = data[0] real_img, label_org, name = data[0]
print("read {}".format(name)) print("read {}".format(name))
...@@ -256,9 +252,9 @@ def infer(args): ...@@ -256,9 +252,9 @@ def infer(args):
elif args.model_net == 'CGAN': elif args.model_net == 'CGAN':
noise_data = np.random.uniform( noise_data = np.random.uniform(
low=-1.0, high=1.0, low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32') size=[args.n_samples, args.noise_size]).astype('float32')
label = np.random.randint( label = np.random.randint(
0, 9, size=[args.batch_size, 1]).astype('float32') 0, 9, size=[args.n_samples, 1]).astype('float32')
noise_tensor = fluid.LoDTensor() noise_tensor = fluid.LoDTensor()
conditions_tensor = fluid.LoDTensor() conditions_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place) noise_tensor.set(noise_data, place)
...@@ -267,7 +263,7 @@ def infer(args): ...@@ -267,7 +263,7 @@ def infer(args):
fetch_list=[fake.name], fetch_list=[fake.name],
feed={"noise": noise_tensor, feed={"noise": noise_tensor,
"conditions": conditions_tensor})[0] "conditions": conditions_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1)) fake_image = np.reshape(fake_temp, (args.n_samples, -1))
fig = utility.plot(fake_image) fig = utility.plot(fake_image)
plt.savefig( plt.savefig(
...@@ -277,12 +273,12 @@ def infer(args): ...@@ -277,12 +273,12 @@ def infer(args):
elif args.model_net == 'DCGAN': elif args.model_net == 'DCGAN':
noise_data = np.random.uniform( noise_data = np.random.uniform(
low=-1.0, high=1.0, low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32') size=[args.n_samples, args.noise_size]).astype('float32')
noise_tensor = fluid.LoDTensor() noise_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place) noise_tensor.set(noise_data, place)
fake_temp = exe.run(fetch_list=[fake.name], fake_temp = exe.run(fetch_list=[fake.name],
feed={"noise": noise_tensor})[0] feed={"noise": noise_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1)) fake_image = np.reshape(fake_temp, (args.n_samples, -1))
fig = utility.plot(fake_image) fig = utility.plot(fake_image)
plt.savefig( plt.savefig(
......
...@@ -71,10 +71,10 @@ class AttGAN_model(object): ...@@ -71,10 +71,10 @@ class AttGAN_model(object):
d = min(dim * 2**i, MAX_DIM) d = min(dim * 2**i, MAX_DIM)
#SAME padding #SAME padding
z = conv2d( z = conv2d(
z, input=z,
d, num_filters=d,
4, filter_size=4,
2, stride=2,
padding_type='SAME', padding_type='SAME',
norm='batch_norm', norm='batch_norm',
activation_fn='leaky_relu', activation_fn='leaky_relu',
...@@ -104,10 +104,10 @@ class AttGAN_model(object): ...@@ -104,10 +104,10 @@ class AttGAN_model(object):
if i < n_layers - 1: if i < n_layers - 1:
d = min(dim * 2**(n_layers - 1 - i), MAX_DIM) d = min(dim * 2**(n_layers - 1 - i), MAX_DIM)
z = deconv2d( z = deconv2d(
z, input=z,
d, num_filters=d,
4, filter_size=4,
2, stride=2,
padding_type='SAME', padding_type='SAME',
name=name + str(i), name=name + str(i),
norm='batch_norm', norm='batch_norm',
...@@ -121,10 +121,10 @@ class AttGAN_model(object): ...@@ -121,10 +121,10 @@ class AttGAN_model(object):
z = self.concat(z, a) z = self.concat(z, a)
else: else:
x = z = deconv2d( x = z = deconv2d(
z, input=z,
3, num_filters=3,
4, filter_size=4,
2, stride=2,
padding_type='SAME', padding_type='SAME',
name=name + str(i), name=name + str(i),
activation_fn='tanh', activation_fn='tanh',
...@@ -146,10 +146,10 @@ class AttGAN_model(object): ...@@ -146,10 +146,10 @@ class AttGAN_model(object):
for i in range(n_layers): for i in range(n_layers):
d = min(dim * 2**i, MAX_DIM) d = min(dim * 2**i, MAX_DIM)
y = conv2d( y = conv2d(
y, input=y,
d, num_filters=d,
4, filter_size=4,
2, stride=2,
norm=norm, norm=norm,
padding=1, padding=1,
activation_fn='leaky_relu', activation_fn='leaky_relu',
...@@ -159,8 +159,8 @@ class AttGAN_model(object): ...@@ -159,8 +159,8 @@ class AttGAN_model(object):
initial='kaiming') initial='kaiming')
logit_gan = linear( logit_gan = linear(
y, input=y,
fc_dim, output_size=fc_dim,
activation_fn='relu', activation_fn='relu',
name=name + 'fc_adv_1', name=name + 'fc_adv_1',
initial='kaiming') initial='kaiming')
...@@ -168,8 +168,8 @@ class AttGAN_model(object): ...@@ -168,8 +168,8 @@ class AttGAN_model(object):
logit_gan, 1, name=name + 'fc_adv_2', initial='kaiming') logit_gan, 1, name=name + 'fc_adv_2', initial='kaiming')
logit_att = linear( logit_att = linear(
y, input=y,
fc_dim, output_size=fc_dim,
activation_fn='relu', activation_fn='relu',
name=name + 'fc_cls_1', name=name + 'fc_cls_1',
initial='kaiming') initial='kaiming')
......
python infer.py --model_net AttGAN --init_model output/checkpoints/199/ --dataset_dir "data/celeba/" --image_size 128 python infer.py --model_net AttGAN --init_model output/checkpoints/119/ --dataset_dir "data/celeba/" --image_size 128
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 90 >log_out 2>log_err python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 >log_out 2>log_err
python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --load_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --image_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err
python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 --epoch 200 --image_size 286 --crop_size 256 > log_out 2>log_err
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
...@@ -24,51 +24,38 @@ import time ...@@ -24,51 +24,38 @@ import time
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import trainer
def train(cfg): def train(cfg):
MODELS = [
"CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN"
]
if cfg.model_net not in MODELS:
raise NotImplementedError("{} is not support!".format(cfg.model_net))
reader = data_reader(cfg) reader = data_reader(cfg)
if cfg.model_net == 'CycleGAN':
if cfg.model_net in ['CycleGAN']:
a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data( a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data(
) )
elif cfg.model_net == 'Pix2pix':
train_reader, test_reader, batch_num = reader.make_data()
elif cfg.model_net == 'StarGAN':
train_reader, test_reader, batch_num = reader.make_data()
else: else:
if cfg.dataset == 'mnist': if cfg.dataset in ['mnist']:
train_reader = reader.make_data() train_reader = reader.make_data()
else: else:
train_reader, test_reader, batch_num = reader.make_data() train_reader, test_reader, batch_num = reader.make_data()
if cfg.model_net == 'CGAN': if cfg.model_net in ['CGAN', 'DCGAN']:
from trainer.CGAN import CGAN
if cfg.dataset != 'mnist':
raise NotImplementedError('CGAN only support mnist now!')
model = CGAN(cfg, train_reader)
elif cfg.model_net == 'DCGAN':
from trainer.DCGAN import DCGAN
if cfg.dataset != 'mnist': if cfg.dataset != 'mnist':
raise NotImplementedError('DCGAN only support mnist now!') raise NotImplementedError("CGAN/DCGAN only support MNIST now!")
model = DCGAN(cfg, train_reader) model = trainer.__dict__[cfg.model_net](cfg, train_reader)
elif cfg.model_net == 'CycleGAN': elif cfg.model_net in ['CycleGAN']:
from trainer.CycleGAN import CycleGAN model = trainer.__dict__[cfg.model_net](
model = CycleGAN(cfg, a_reader, b_reader, a_reader_test, b_reader_test, cfg, a_reader, b_reader, a_reader_test, b_reader_test, batch_num)
batch_num)
elif cfg.model_net == 'Pix2pix':
from trainer.Pix2pix import Pix2pix
model = Pix2pix(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'StarGAN':
from trainer.StarGAN import StarGAN
model = StarGAN(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'AttGAN':
from trainer.AttGAN import AttGAN
model = AttGAN(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'STGAN':
from trainer.STGAN import STGAN
model = STGAN(cfg, train_reader, test_reader, batch_num)
else: else:
pass model = trainer.__dict__[cfg.model_net](cfg, train_reader, test_reader,
batch_num)
model.build_model() model.build_model()
...@@ -77,13 +64,13 @@ if __name__ == "__main__": ...@@ -77,13 +64,13 @@ if __name__ == "__main__":
cfg = config.parse_args() cfg = config.parse_args()
config.print_arguments(cfg) config.print_arguments(cfg)
utility.check_gpu(cfg.use_gpu) utility.check_gpu(cfg.use_gpu)
#assert cfg.load_size >= cfg.crop_size, "Load Size CANNOT less than Crop Size!"
if cfg.profile: if cfg.profile:
if cfg.use_gpu: if cfg.use_gpu:
with profiler.profiler('All', 'total', '/tmp/profile') as prof: with fluid.profiler.profiler('All', 'total',
'/tmp/profile') as prof:
train(cfg) train(cfg)
else: else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof: with fluid.profiler.profiler("CPU", sorted_key='total') as cpuprof:
train(cfg) train(cfg)
else: else:
train(cfg) train(cfg)
...@@ -175,8 +175,6 @@ class DTrainer(): ...@@ -175,8 +175,6 @@ class DTrainer():
class AttGAN(object): class AttGAN(object):
def add_special_args(self, parser): def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument( parser.add_argument(
'--g_lr', '--g_lr',
type=float, type=float,
......
...@@ -173,8 +173,6 @@ class DTrainer(): ...@@ -173,8 +173,6 @@ class DTrainer():
class STGAN(object): class STGAN(object):
def add_special_args(self, parser): def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument( parser.add_argument(
'--g_lr', '--g_lr',
type=float, type=float,
......
...@@ -199,8 +199,6 @@ class DTrainer(): ...@@ -199,8 +199,6 @@ class DTrainer():
class StarGAN(object): class StarGAN(object):
def add_special_args(self, parser): def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument( parser.add_argument(
'--g_lr', type=float, default=0.0001, help="learning rate of g") '--g_lr', type=float, default=0.0001, help="learning rate of g")
parser.add_argument( parser.add_argument(
......
...@@ -12,6 +12,14 @@ ...@@ -12,6 +12,14 @@
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
from .CGAN import CGAN
from .DCGAN import DCGAN
from .CycleGAN import CycleGAN
from .Pix2pix import Pix2pix
from .STGAN import STGAN
from .StarGAN import StarGAN
from .AttGAN import AttGAN
import importlib import importlib
......
...@@ -77,7 +77,7 @@ def base_parse_args(parser): ...@@ -77,7 +77,7 @@ def base_parse_args(parser):
add_arg('epoch', int, 200, "The number of epoch to be trained.") add_arg('epoch', int, 200, "The number of epoch to be trained.")
add_arg('g_base_dims', int, 64, "Base channels in generator") add_arg('g_base_dims', int, 64, "Base channels in generator")
add_arg('d_base_dims', int, 64, "Base channels in discriminator") add_arg('d_base_dims', int, 64, "Base channels in discriminator")
add_arg('load_size', int, 286, "the image size when load the image") add_arg('image_size', int, 286, "the image size when load the image")
add_arg('crop_type', str, 'Centor', add_arg('crop_type', str, 'Centor',
"the crop type, choose = ['Centor', 'Random']") "the crop type, choose = ['Centor', 'Random']")
add_arg('crop_size', int, 256, "crop size when preprocess image") add_arg('crop_size', int, 256, "crop size when preprocess image")
......
...@@ -113,9 +113,10 @@ def save_test_image(epoch, ...@@ -113,9 +113,10 @@ def save_test_image(epoch,
images = [real_img_temp] images = [real_img_temp]
for i in range(cfg.c_dim): for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org) label_trg_tmp = copy.deepcopy(label_org)
label_trg_tmp[0][i] = 1.0 - label_trg_tmp[0][i] for j in range(len(label_org)):
label_trg = check_attribute_conflict(label_trg_tmp, label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
attr_names[i], attr_names) label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor() tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place) tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run( fake_temp, rec_temp = exe.run(
...@@ -126,11 +127,13 @@ def save_test_image(epoch, ...@@ -126,11 +127,13 @@ def save_test_image(epoch,
"label_trg": tensor_label_trg "label_trg": tensor_label_trg
}, },
fetch_list=[g_trainer.fake_img, g_trainer.rec_img]) fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
fake_temp = save_batch_image(fake_temp[0]) fake_temp = save_batch_image(fake_temp)
rec_temp = save_batch_image(rec_temp[0]) rec_temp = save_batch_image(rec_temp)
images.append(fake_temp) images.append(fake_temp)
images.append(rec_temp) images.append(rec_temp)
images_concat = np.concatenate(images, 1) images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(out_path + "/fake_img" + str(epoch) + "_" + name[0], imageio.imwrite(out_path + "/fake_img" + str(epoch) + "_" + name[0],
((images_concat + 1) * 127.5).astype(np.uint8)) ((images_concat + 1) * 127.5).astype(np.uint8))
elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN': elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
...@@ -184,12 +187,12 @@ def save_test_image(epoch, ...@@ -184,12 +187,12 @@ def save_test_image(epoch,
else: else:
for data_A, data_B in zip(A_test_reader(), B_test_reader()): for data_A, data_B in zip(A_test_reader(), B_test_reader()):
A_name = data_A[0][1] A_data, A_name = data_A
B_name = data_B[0][1] B_data, B_name = data_B
tensor_A = fluid.LoDTensor() tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor() tensor_B = fluid.LoDTensor()
tensor_A.set(data_A[0][0], place) tensor_A.set(A_data, place)
tensor_B.set(data_B[0][0], place) tensor_B.set(B_data, place)
fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run( fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
test_program, test_program,
fetch_list=[ fetch_list=[
...@@ -205,17 +208,19 @@ def save_test_image(epoch, ...@@ -205,17 +208,19 @@ def save_test_image(epoch,
input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0]) input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0]) input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])
imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name, ( imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name[0],
(fake_B_temp + 1) * 127.5).astype(np.uint8)) ((fake_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name, ( imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name[0],
(fake_A_temp + 1) * 127.5).astype(np.uint8)) ((fake_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name, ( imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name[0],
(cyc_A_temp + 1) * 127.5).astype(np.uint8)) ((cyc_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name, ( imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name[0],
(cyc_B_temp + 1) * 127.5).astype(np.uint8)) ((cyc_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/inputA_" + str(epoch) + "_" + A_name, ( imageio.imwrite(
out_path + "/inputA_" + str(epoch) + "_" + A_name[0], (
(input_A_temp + 1) * 127.5).astype(np.uint8)) (input_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/inputB_" + str(epoch) + "_" + B_name, ( imageio.imwrite(
out_path + "/inputB_" + str(epoch) + "_" + B_name[0], (
(input_B_temp + 1) * 127.5).astype(np.uint8)) (input_B_temp + 1) * 127.5).astype(np.uint8))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册