Unverified commit 9486fc66, authored by lvmengsi, committed by GitHub

fix bug for gan (#2329)

Parent b9845895
@@ -26,7 +26,7 @@ import random

 def RandomCrop(img, crop_w, crop_h):
-    w, h = img.shape[0], img.shape[1]
+    w, h = img.size[0], img.size[1]
     i = np.random.randint(0, w - crop_w)
     j = np.random.randint(0, h - crop_h)
     return img.crop((i, j, i + crop_w, j + crop_h))
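This one-attribute change is the actual bug fix: `img` is a `PIL.Image`, which exposes its dimensions as `.size` (a `(width, height)` tuple), while `.shape` only exists on NumPy arrays, so the removed line raised `AttributeError` on PIL input. A minimal self-contained sketch of the fixed helper, assuming a PIL input and the 286-to-256 load/crop sizes used elsewhere in this commit:

import numpy as np
from PIL import Image

def RandomCrop(img, crop_w, crop_h):
    # PIL images expose (width, height) via .size; .shape is a NumPy-only
    # attribute, which is what the old line tripped over.
    w, h = img.size[0], img.size[1]
    i = np.random.randint(0, w - crop_w)
    j = np.random.randint(0, h - crop_h)
    # crop() takes a (left, upper, right, lower) pixel box
    return img.crop((i, j, i + crop_w, j + crop_h))

img = Image.new("RGB", (286, 286))  # dummy input at --load_size 286
patch = RandomCrop(img, 256, 256)   # --crop_size 256
assert patch.size == (256, 256)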
@@ -346,7 +346,7 @@ class data_reader(object):
             return a_reader, b_reader, a_reader_test, b_reader_test, batch_num
-        elif self.cfg.model_net == 'Pix2pix':
+        else:
             dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
             train_list = os.path.join(dataset_dir, 'train.txt')
             if self.cfg.train_list is not None:

@@ -372,22 +372,3 @@ class data_reader(object):
                 reader = train_reader.get_train_reader(
                     self.cfg, shuffle=self.shuffle)
             return reader, reader_test, batch_num
-        else:
-            dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
-            train_list = os.path.join(dataset_dir, 'train.txt')
-            if self.cfg.data_list is not None:
-                train_list = self.cfg.data_list
-            train_reader = reader_creator(
-                image_dir=dataset_dir, list_filename=train_list)
-            reader_test = None
-            if self.cfg.run_test:
-                test_list = os.path.join(dataset_dir, "test.txt")
-                test_reader = reader_creator(
-                    image_dir=dataset_dir,
-                    list_filename=test_list,
-                    batch_size=1,
-                    drop_last=self.cfg.drop_last)
-                reader_test = test_reader.get_test_reader(
-                    self.cfg, shuffle=False, return_name=True)
-            batch_num = train_reader.len()
-            return train_reader, reader_test, batch_num
@@ -23,6 +23,7 @@ import argparse
 import requests
 import six
 import hashlib
+import zipfile

 parser = argparse.ArgumentParser(description='Download dataset.')
 #TODO add celeA dataset
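The added `import zipfile` suggests the download script now unpacks fetched archives with the standard library rather than an external tool; the rest of that change is collapsed in this view. A hedged sketch of the stdlib pattern (function name and paths are illustrative, not the script's actual values):

import os
import zipfile

def unzip(archive_path, target_dir):
    # Extract a downloaded dataset archive without shelling out to 'unzip'.
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(target_dir)

# e.g. unzip('./data/cityscapes.zip', './data')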
@@ -31,7 +31,7 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('model_net', str, 'cgan', "The model used")
-add_arg('net_G', str, "resnet_9block", "Choose the CycleGAN generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]")
+add_arg('net_G', str, "resnet_9block", "Choose the CycleGAN and Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]")
 add_arg('input', str, None, "The images to be infered.")
 add_arg('init_model', str, None, "The init model file of directory.")
 add_arg('output', str, "./infer_result", "The directory the infer result to be saved to.")
 python infer.py --init_model output/checkpoints/199/ --input data/cityscapes/testA/* --input_style A --model_net cyclegan --net_G resnet_6block --g_bash_dims 32
-python infer.py --init_model output/chechpoints/15/ --input data/cityscapes/test/B/100.jpg --model_net Pix2pix --net_G unet_256
+python infer.py --init_model output/chechpoints/199/ --input data/cityscapes/testB/* --model_net Pix2pix --net_G unet_256
+import os
+
+
+def make_pair_data(fileA, file):
+    f = open(fileA, 'r')
+    lines = f.readlines()
+    w = open(file, 'w')
+    for line in lines:
+        fileA = line[:-1]
+        print(fileA)
+        fileB = fileA.replace("A", "B")
+        print(fileB)
+        l = fileA + '\t' + fileB + '\n'
+        w.write(l)
+    w.close()
+
+
+if __name__ == "__main__":
+    trainA_file = "./data/cityscapes/trainA.txt"
+    train_file = "./data/cityscapes/pix2pix_train_list"
+    make_pair_data(trainA_file, train_file)
+    testA_file = "./data/cityscapes/testA.txt"
+    test_file = "./data/cityscapes/pix2pix_test_list"
+    make_pair_data(testA_file, test_file)
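Run against the CycleGAN-style single-domain lists, this new script emits the tab-separated A/B pairs that the Pix2pix reader expects. Note that `str.replace("A", "B")` rewrites every "A" in the path, which is safe for the trainA/trainB cityscapes layout but worth checking for other datasets. A small sketch of its effect, assuming `make_pair_data` from the script above is importable (file names are illustrative):

with open("trainA.txt", "w") as f:
    f.write("trainA/1.jpg\ntrainA/2.jpg\n")

make_pair_data("trainA.txt", "pair_list")

print(open("pair_list").read())
# trainA/1.jpg	trainB/1.jpg
# trainA/2.jpg	trainB/2.jpg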
 python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dim 32 --net_D basic --norm_type batch_norm --epoch 200 --load_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err
-python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list10 --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err
+python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err
@@ -35,6 +35,7 @@ class GTrainer():
         with fluid.program_guard(self.program):
             model = CGAN_model()
             self.fake = model.network_G(input, conditions, name="G")
+            self.fake.persistable = True
             self.infer_program = self.program.clone()
             d_fake = model.network_D(self.fake, conditions, name="D")
             fake_labels = fluid.layers.fill_constant_batch_size_like(

@@ -42,6 +43,7 @@ class GTrainer():
             self.g_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     x=d_fake, label=fake_labels))
+            self.g_loss.persistable = True

             vars = []
             for var in self.program.list_vars():

@@ -62,7 +64,7 @@ class DTrainer():
             self.d_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     x=d_logit, label=labels))
-
+            self.d_loss.persistable = True
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and (var.name.startswith("D")):

@@ -112,7 +114,7 @@ class CGAN(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
         build_strategy.enable_inplace = True
-        build_strategy.memory_optimize = False
+        build_strategy.memory_optimize = True

         g_trainer_program = fluid.CompiledProgram(
             g_trainer.program).with_data_parallel(
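The recurring `persistable = True` additions and the `memory_optimize = True` flips across these trainers are two halves of one change: once fluid's memory optimizer is allowed to recycle variable buffers, any variable the training loop still fetches (losses, generated images) must be pinned explicitly or its buffer may be reused. A minimal sketch of the pattern under the fluid 1.x API used in this diff (variable names are illustrative, not the trainer's real ones):

import numpy as np
import paddle.fluid as fluid

main = fluid.Program()
with fluid.program_guard(main):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    loss = fluid.layers.reduce_mean(fluid.layers.square(x))
    # With memory_optimize on, non-persistable intermediates may have their
    # buffers recycled, so pin anything that goes into fetch_list:
    loss.persistable = True

build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = True
compiled = fluid.CompiledProgram(main).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)

exe = fluid.Executor(fluid.CPUPlace())
loss_val, = exe.run(compiled,
                    feed={'x': np.ones([2, 4], dtype='float32')},
                    fetch_list=[loss.name])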
@@ -47,16 +47,20 @@ class GTrainer():
                 fluid.layers.elementwise_sub(
                     x=input_B, y=self.cyc_B))
             self.cyc_A_loss = fluid.layers.reduce_mean(diff_A) * lambda_A
+            self.cyc_A_loss.persistable = True
             self.cyc_B_loss = fluid.layers.reduce_mean(diff_B) * lambda_B
+            self.cyc_B_loss.persistable = True
             self.cyc_loss = self.cyc_A_loss + self.cyc_B_loss
             # GAN Loss D_A(G_A(A))
             self.fake_rec_A = model.network_D(self.fake_B, name="DA", cfg=cfg)
             self.G_A = fluid.layers.reduce_mean(
                 fluid.layers.square(self.fake_rec_A - 1))
+            self.G_A.persistable = True
             # GAN Loss D_B(G_B(B))
             self.fake_rec_B = model.network_D(self.fake_A, name="DB", cfg=cfg)
             self.G_B = fluid.layers.reduce_mean(
                 fluid.layers.square(self.fake_rec_B - 1))
+            self.G_B.persistable = True
             self.G = self.G_A + self.G_B
             # Identity Loss G_A
             self.idt_A = model.network_G(input_B, name="GA", cfg=cfg)

@@ -64,12 +68,14 @@ class GTrainer():
                 fluid.layers.abs(
                     fluid.layers.elementwise_sub(
                         x=input_B, y=self.idt_A))) * lambda_B * lambda_identity
+            self.idt_loss_A.persistable = True
             # Identity Loss G_B
             self.idt_B = model.network_G(input_A, name="GB", cfg=cfg)
             self.idt_loss_B = fluid.layers.reduce_mean(
                 fluid.layers.abs(
                     fluid.layers.elementwise_sub(
                         x=input_A, y=self.idt_B))) * lambda_A * lambda_identity
+            self.idt_loss_B.persistable = True
             self.idt_loss = fluid.layers.elementwise_add(self.idt_loss_A,
                                                          self.idt_loss_B)

@@ -107,8 +113,8 @@ class DATrainer():
             self.d_loss_A = (fluid.layers.square(self.fake_pool_rec_B) +
                              fluid.layers.square(self.rec_B - 1)) / 2.0
             self.d_loss_A = fluid.layers.reduce_mean(self.d_loss_A)
+            self.d_loss_A.persistable = True
-            optimizer = fluid.optimizer.Adam(learning_rate=0.0002, beta1=0.5)
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and var.name.startswith("DA"):

@@ -142,7 +148,7 @@ class DBTrainer():
             self.d_loss_B = (fluid.layers.square(self.fake_pool_rec_A) +
                              fluid.layers.square(self.rec_A - 1)) / 2.0
             self.d_loss_B = fluid.layers.reduce_mean(self.d_loss_B)
-            optimizer = fluid.optimizer.Adam(learning_rate=0.0002, beta1=0.5)
+            self.d_loss_B.persistable = True
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and var.name.startswith("DB"):

@@ -230,8 +236,8 @@ class CycleGAN(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
-        build_strategy.enable_inplace = False
-        build_strategy.memory_optimize = False
+        build_strategy.enable_inplace = True
+        build_strategy.memory_optimize = True

         gen_trainer_program = fluid.CompiledProgram(
             gen_trainer.program).with_data_parallel(
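The `G_A`/`G_B` and `d_loss_A`/`d_loss_B` terms above are the least-squares GAN objective: the generator pushes the discriminator's score on fakes toward 1, while the discriminator pushes reals toward 1 and pooled fakes toward 0. A NumPy sketch of the same arithmetic, outside of fluid, just to make the targets explicit:

import numpy as np

def lsgan_g_loss(fake_logits):
    # Generator wants D(fake) close to 1: mean((D(G(x)) - 1)^2),
    # mirroring reduce_mean(square(self.fake_rec_A - 1)) above.
    return np.mean(np.square(fake_logits - 1))

def lsgan_d_loss(real_logits, fake_pool_logits):
    # Discriminator wants D(real) -> 1 and D(fake) -> 0:
    # (square(fake) + square(real - 1)) / 2, as in d_loss_A above.
    return np.mean((np.square(fake_pool_logits) +
                    np.square(real_logits - 1)) / 2.0)

real = np.array([0.9, 1.1])
fake = np.array([0.2, 0.1])
print(lsgan_d_loss(real, fake))  # small: D is doing well
print(lsgan_g_loss(fake))        # large: G has room to improve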
@@ -35,6 +35,7 @@ class GTrainer():
         with fluid.program_guard(self.program):
             model = DCGAN_model()
             self.fake = model.network_G(input, name='G')
+            self.fake.persistable = True
             self.infer_program = self.program.clone()
             d_fake = model.network_D(self.fake, name="D")
             fake_labels = fluid.layers.fill_constant_batch_size_like(

@@ -42,6 +43,7 @@ class GTrainer():
             self.g_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     x=d_fake, label=fake_labels))
+            self.g_loss.persistable = True

             vars = []
             for var in self.program.list_vars():

@@ -61,6 +63,7 @@ class DTrainer():
             self.d_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     x=d_logit, label=labels))
+            self.d_loss.persistable = True
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and (var.name.startswith("D")):

@@ -78,7 +81,7 @@ class DCGAN(object):
         return parser

-    def __init__(self, cfg, train_reader):
+    def __init__(self, cfg=None, train_reader=None):
         self.cfg = cfg
         self.train_reader = train_reader

@@ -107,7 +110,7 @@ class DCGAN(object):
         ### memory optim
         build_strategy = fluid.BuildStrategy()
         build_strategy.enable_inplace = True
-        build_strategy.memory_optimize = False
+        build_strategy.memory_optimize = True

         g_trainer_program = fluid.CompiledProgram(
             g_trainer.program).with_data_parallel(
@@ -68,16 +68,15 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
 def base_parse_args(parser):
     add_arg = functools.partial(add_arguments, argparser=parser)
     # yapf: disable
-    add_arg('model_net', str, "cgan", "The model used.")
+    add_arg('model_net', str, "CGAN", "The model used.")
     add_arg('dataset', str, "mnist", "The dataset used.")
     add_arg('data_dir', str, "./data", "The dataset root directory")
-    add_arg('data_list', str, "data/cityscapes/pix2pix_train_list", "The data list file name")
-    add_arg('train_list', str, "data/cityscapes/pix2pix_train_list", "The train list file name")
-    add_arg('test_list', str, "data/cityscapes/pix2pix_test_list10", "The test list file name")
+    add_arg('train_list', str, None, "The train list file name")
+    add_arg('test_list', str, None, "The test list file name")
     add_arg('batch_size', int, 1, "Minibatch size.")
     add_arg('epoch', int, 200, "The number of epoch to be trained.")
-    add_arg('g_base_dims', int, 64, "Base channels in CycleGAN generator")
-    add_arg('d_base_dims', int, 64, "Base channels in CycleGAN discriminator")
+    add_arg('g_base_dims', int, 64, "Base channels in generator")
+    add_arg('d_base_dims', int, 64, "Base channels in discriminator")
     add_arg('load_size', int, 286, "the image size when load the image")
     add_arg('crop_type', str, 'Centor',
             "the crop type, choose = ['Centor', 'Random']")
@@ -96,7 +95,7 @@ def base_parse_args(parser):
     add_arg('gan_mode', str, "vanilla", "The init model file of directory.")
     add_arg('norm_type', str, "batch_norm", "Which normalization to used")
     add_arg('learning_rate', float, 0.0002, "the initialize learning rate")
-    add_arg('lambda_L1', float, 100.0, "the initialize learning rate")
+    add_arg('lambda_L1', float, 100.0, "the initialize lambda parameter for L1 loss")
     add_arg('num_generator_time', int, 1,
             "the generator run times in training each epoch")
     add_arg('print_freq', int, 10, "the frequency of print loss")