未验证 提交 4a93b001 编写于 作者: L lvmengsi 提交者: GitHub

Update gan (#2871)

* refine gan
上级 7b1a5565
此差异已折叠。
......@@ -54,8 +54,8 @@ add_arg('image_size', int, 128, "image size")
add_arg('selected_attrs', str,
"Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
"the attributes we selected to change")
add_arg('batch_size', int, 16, "batch size when test")
add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt", "the test list file")
add_arg('n_samples', int, 16, "batch size when test")
add_arg('test_list', str, "./data/celeba/list_attr_celeba.txt", "the test list file")
add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory to be infered")
add_arg('n_layers', int, 5, "default layers in generotor")
add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor")
......@@ -149,11 +149,9 @@ def infer(args):
test_reader = celeba_reader_creator(
image_dir=args.dataset_dir,
list_filename=args.test_list,
batch_size=args.batch_size,
drop_last=False,
args=args)
reader_test = test_reader.get_test_reader(
args, shuffle=False, return_name=True)
args=args,
mode="VAL")
reader_test = test_reader.make_reader(return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
print("read {}".format(name))
......@@ -199,11 +197,9 @@ def infer(args):
test_reader = celeba_reader_creator(
image_dir=args.dataset_dir,
list_filename=args.test_list,
batch_size=args.batch_size,
drop_last=False,
args=args)
reader_test = test_reader.get_test_reader(
args, shuffle=False, return_name=True)
args=args,
mode="VAL")
reader_test = test_reader.make_reader(return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
print("read {}".format(name))
......@@ -256,9 +252,9 @@ def infer(args):
elif args.model_net == 'CGAN':
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32')
size=[args.n_samples, args.noise_size]).astype('float32')
label = np.random.randint(
0, 9, size=[args.batch_size, 1]).astype('float32')
0, 9, size=[args.n_samples, 1]).astype('float32')
noise_tensor = fluid.LoDTensor()
conditions_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place)
......@@ -267,7 +263,7 @@ def infer(args):
fetch_list=[fake.name],
feed={"noise": noise_tensor,
"conditions": conditions_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fake_image = np.reshape(fake_temp, (args.n_samples, -1))
fig = utility.plot(fake_image)
plt.savefig(
......@@ -277,12 +273,12 @@ def infer(args):
elif args.model_net == 'DCGAN':
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32')
size=[args.n_samples, args.noise_size]).astype('float32')
noise_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place)
fake_temp = exe.run(fetch_list=[fake.name],
feed={"noise": noise_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fake_image = np.reshape(fake_temp, (args.n_samples, -1))
fig = utility.plot(fake_image)
plt.savefig(
......
......@@ -71,10 +71,10 @@ class AttGAN_model(object):
d = min(dim * 2**i, MAX_DIM)
#SAME padding
z = conv2d(
z,
d,
4,
2,
input=z,
num_filters=d,
filter_size=4,
stride=2,
padding_type='SAME',
norm='batch_norm',
activation_fn='leaky_relu',
......@@ -104,10 +104,10 @@ class AttGAN_model(object):
if i < n_layers - 1:
d = min(dim * 2**(n_layers - 1 - i), MAX_DIM)
z = deconv2d(
z,
d,
4,
2,
input=z,
num_filters=d,
filter_size=4,
stride=2,
padding_type='SAME',
name=name + str(i),
norm='batch_norm',
......@@ -121,10 +121,10 @@ class AttGAN_model(object):
z = self.concat(z, a)
else:
x = z = deconv2d(
z,
3,
4,
2,
input=z,
num_filters=3,
filter_size=4,
stride=2,
padding_type='SAME',
name=name + str(i),
activation_fn='tanh',
......@@ -146,10 +146,10 @@ class AttGAN_model(object):
for i in range(n_layers):
d = min(dim * 2**i, MAX_DIM)
y = conv2d(
y,
d,
4,
2,
input=y,
num_filters=d,
filter_size=4,
stride=2,
norm=norm,
padding=1,
activation_fn='leaky_relu',
......@@ -159,8 +159,8 @@ class AttGAN_model(object):
initial='kaiming')
logit_gan = linear(
y,
fc_dim,
input=y,
output_size=fc_dim,
activation_fn='relu',
name=name + 'fc_adv_1',
initial='kaiming')
......@@ -168,8 +168,8 @@ class AttGAN_model(object):
logit_gan, 1, name=name + 'fc_adv_2', initial='kaiming')
logit_att = linear(
y,
fc_dim,
input=y,
output_size=fc_dim,
activation_fn='relu',
name=name + 'fc_cls_1',
initial='kaiming')
......
python infer.py --model_net AttGAN --init_model output/checkpoints/199/ --dataset_dir "data/celeba/" --image_size 128
python infer.py --model_net AttGAN --init_model output/checkpoints/119/ --dataset_dir "data/celeba/" --image_size 128
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 90 >log_out 2>log_err
python train.py --model_net AttGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 120 >log_out 2>log_err
python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --load_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err
python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --image_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err
python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err
python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 --epoch 200 --image_size 286 --crop_size 256 > log_out 2>log_err
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 50 >log_out 2>log_err
......@@ -24,51 +24,38 @@ import time
import numpy as np
import paddle
import paddle.fluid as fluid
import trainer
def train(cfg):
MODELS = [
"CGAN", "DCGAN", "Pix2pix", "CycleGAN", "StarGAN", "AttGAN", "STGAN"
]
if cfg.model_net not in MODELS:
raise NotImplementedError("{} is not support!".format(cfg.model_net))
reader = data_reader(cfg)
if cfg.model_net == 'CycleGAN':
if cfg.model_net in ['CycleGAN']:
a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data(
)
elif cfg.model_net == 'Pix2pix':
train_reader, test_reader, batch_num = reader.make_data()
elif cfg.model_net == 'StarGAN':
train_reader, test_reader, batch_num = reader.make_data()
else:
if cfg.dataset == 'mnist':
if cfg.dataset in ['mnist']:
train_reader = reader.make_data()
else:
train_reader, test_reader, batch_num = reader.make_data()
if cfg.model_net == 'CGAN':
from trainer.CGAN import CGAN
if cfg.dataset != 'mnist':
raise NotImplementedError('CGAN only support mnist now!')
model = CGAN(cfg, train_reader)
elif cfg.model_net == 'DCGAN':
from trainer.DCGAN import DCGAN
if cfg.model_net in ['CGAN', 'DCGAN']:
if cfg.dataset != 'mnist':
raise NotImplementedError('DCGAN only support mnist now!')
model = DCGAN(cfg, train_reader)
elif cfg.model_net == 'CycleGAN':
from trainer.CycleGAN import CycleGAN
model = CycleGAN(cfg, a_reader, b_reader, a_reader_test, b_reader_test,
batch_num)
elif cfg.model_net == 'Pix2pix':
from trainer.Pix2pix import Pix2pix
model = Pix2pix(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'StarGAN':
from trainer.StarGAN import StarGAN
model = StarGAN(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'AttGAN':
from trainer.AttGAN import AttGAN
model = AttGAN(cfg, train_reader, test_reader, batch_num)
elif cfg.model_net == 'STGAN':
from trainer.STGAN import STGAN
model = STGAN(cfg, train_reader, test_reader, batch_num)
raise NotImplementedError("CGAN/DCGAN only support MNIST now!")
model = trainer.__dict__[cfg.model_net](cfg, train_reader)
elif cfg.model_net in ['CycleGAN']:
model = trainer.__dict__[cfg.model_net](
cfg, a_reader, b_reader, a_reader_test, b_reader_test, batch_num)
else:
pass
model = trainer.__dict__[cfg.model_net](cfg, train_reader, test_reader,
batch_num)
model.build_model()
......@@ -77,13 +64,13 @@ if __name__ == "__main__":
cfg = config.parse_args()
config.print_arguments(cfg)
utility.check_gpu(cfg.use_gpu)
#assert cfg.load_size >= cfg.crop_size, "Load Size CANNOT less than Crop Size!"
if cfg.profile:
if cfg.use_gpu:
with profiler.profiler('All', 'total', '/tmp/profile') as prof:
with fluid.profiler.profiler('All', 'total',
'/tmp/profile') as prof:
train(cfg)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
with fluid.profiler.profiler("CPU", sorted_key='total') as cpuprof:
train(cfg)
else:
train(cfg)
......@@ -175,8 +175,6 @@ class DTrainer():
class AttGAN(object):
def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument(
'--g_lr',
type=float,
......
......@@ -173,8 +173,6 @@ class DTrainer():
class STGAN(object):
def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument(
'--g_lr',
type=float,
......
......@@ -199,8 +199,6 @@ class DTrainer():
class StarGAN(object):
def add_special_args(self, parser):
parser.add_argument(
'--image_size', type=int, default=256, help="image size")
parser.add_argument(
'--g_lr', type=float, default=0.0001, help="learning rate of g")
parser.add_argument(
......
......@@ -12,6 +12,14 @@
#See the License for the specific language governing permissions and
#limitations under the License.
from .CGAN import CGAN
from .DCGAN import DCGAN
from .CycleGAN import CycleGAN
from .Pix2pix import Pix2pix
from .STGAN import STGAN
from .StarGAN import StarGAN
from .AttGAN import AttGAN
import importlib
......
......@@ -77,7 +77,7 @@ def base_parse_args(parser):
add_arg('epoch', int, 200, "The number of epoch to be trained.")
add_arg('g_base_dims', int, 64, "Base channels in generator")
add_arg('d_base_dims', int, 64, "Base channels in discriminator")
add_arg('load_size', int, 286, "the image size when load the image")
add_arg('image_size', int, 286, "the image size when load the image")
add_arg('crop_type', str, 'Centor',
"the crop type, choose = ['Centor', 'Random']")
add_arg('crop_size', int, 256, "crop size when preprocess image")
......
......@@ -113,9 +113,10 @@ def save_test_image(epoch,
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
label_trg_tmp[0][i] = 1.0 - label_trg_tmp[0][i]
label_trg = check_attribute_conflict(label_trg_tmp,
attr_names[i], attr_names)
for j in range(len(label_org)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run(
......@@ -126,11 +127,13 @@ def save_test_image(epoch,
"label_trg": tensor_label_trg
},
fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
fake_temp = save_batch_image(fake_temp[0])
rec_temp = save_batch_image(rec_temp[0])
fake_temp = save_batch_image(fake_temp)
rec_temp = save_batch_image(rec_temp)
images.append(fake_temp)
images.append(rec_temp)
images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(out_path + "/fake_img" + str(epoch) + "_" + name[0],
((images_concat + 1) * 127.5).astype(np.uint8))
elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
......@@ -184,12 +187,12 @@ def save_test_image(epoch,
else:
for data_A, data_B in zip(A_test_reader(), B_test_reader()):
A_name = data_A[0][1]
B_name = data_B[0][1]
A_data, A_name = data_A
B_data, B_name = data_B
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_A.set(data_A[0][0], place)
tensor_B.set(data_B[0][0], place)
tensor_A.set(A_data, place)
tensor_B.set(B_data, place)
fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
test_program,
fetch_list=[
......@@ -205,18 +208,20 @@ def save_test_image(epoch,
input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])
imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name, (
(fake_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name, (
(fake_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name, (
(cyc_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name, (
(cyc_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/inputA_" + str(epoch) + "_" + A_name, (
(input_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/inputB_" + str(epoch) + "_" + B_name, (
(input_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name[0],
((fake_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name[0],
((fake_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name[0],
((cyc_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name[0],
((cyc_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
out_path + "/inputA_" + str(epoch) + "_" + A_name[0], (
(input_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
out_path + "/inputB_" + str(epoch) + "_" + B_name[0], (
(input_B_temp + 1) * 127.5).astype(np.uint8))
class ImagePool(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册