Unverified  Commit 9486fc66  authored by lvmengsi, committed by GitHub

fix bug for gan (#2329)

Parent b9845895
@@ -26,7 +26,7 @@ import random
 def RandomCrop(img, crop_w, crop_h):
-    w, h = img.shape[0], img.shape[1]
+    w, h = img.size[0], img.size[1]
     i = np.random.randint(0, w - crop_w)
     j = np.random.randint(0, h - crop_h)
     return img.crop((i, j, i + crop_w, j + crop_h))
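The change above swaps `img.shape` for `img.size`: the reader hands RandomCrop a PIL Image, which exposes (width, height) through `.size`, while `.shape` exists only on NumPy arrays, so the old line raised AttributeError. A minimal sketch of the fixed function in isolation (dummy image size assumed):

    from PIL import Image
    import numpy as np

    def RandomCrop(img, crop_w, crop_h):
        # PIL gives (width, height) via .size; .shape is NumPy-only.
        w, h = img.size[0], img.size[1]
        i = np.random.randint(0, w - crop_w)
        j = np.random.randint(0, h - crop_h)
        return img.crop((i, j, i + crop_w, j + crop_h))

    img = Image.new('RGB', (286, 286))   # dummy image at load_size
    patch = RandomCrop(img, 256, 256)
    assert patch.size == (256, 256)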
@@ -346,7 +346,7 @@ class data_reader(object):
             return a_reader, b_reader, a_reader_test, b_reader_test, batch_num
-        elif self.cfg.model_net == 'Pix2pix':
+        else:
             dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
             train_list = os.path.join(dataset_dir, 'train.txt')
             if self.cfg.train_list is not None:
@@ -372,22 +372,3 @@ class data_reader(object):
             reader = train_reader.get_train_reader(
                 self.cfg, shuffle=self.shuffle)
             return reader, reader_test, batch_num
-        else:
-            dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
-            train_list = os.path.join(dataset_dir, 'train.txt')
-            if self.cfg.data_list is not None:
-                train_list = self.cfg.data_list
-            train_reader = reader_creator(
-                image_dir=dataset_dir, list_filename=train_list)
-            reader_test = None
-            if self.cfg.run_test:
-                test_list = os.path.join(dataset_dir, "test.txt")
-                test_reader = reader_creator(
-                    image_dir=dataset_dir,
-                    list_filename=test_list,
-                    batch_size=1,
-                    drop_last=self.cfg.drop_last)
-                reader_test = test_reader.get_test_reader(
-                    self.cfg, shuffle=False, return_name=True)
-            batch_num = train_reader.len()
-            return train_reader, reader_test, batch_num
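Together with the `elif self.cfg.model_net == 'Pix2pix':` → `else:` change above, this deletion collapses two nearly identical branches: the generic paired-data path now serves Pix2pix and any other paired model, and the `data_list` override it relied on is dropped in favor of the explicit `train_list`/`test_list` options (see the argument changes in the last hunk below).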
@@ -23,6 +23,7 @@ import argparse
 import requests
 import six
 import hashlib
+import zipfile
 parser = argparse.ArgumentParser(description='Download dataset.')
 #TODO add celeA dataset
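The new `zipfile` import indicates the download script now unpacks fetched archives itself; the code that uses it is outside this hunk. A minimal sketch of the usual pattern (function name and paths are assumptions, not taken from this diff):

    import zipfile

    def unzip(zip_path, out_dir):
        # Extract a downloaded dataset archive into the data directory.
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(out_dir)

    # unzip('./data/cityscapes.zip', './data')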
......
@@ -31,7 +31,7 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('model_net', str, 'cgan', "The model used")
-add_arg('net_G', str, "resnet_9block", "Choose the CycleGAN generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]")
+add_arg('net_G', str, "resnet_9block", "Choose the CycleGAN and Pix2pix generator's network, choose from [resnet_9block|resnet_6block|unet_128|unet_256]")
 add_arg('input', str, None, "The images to be inferred.")
 add_arg('init_model', str, None, "The init model file or directory.")
 add_arg('output', str, "./infer_result", "The directory where inference results are saved.")
......
python infer.py --init_model output/checkpoints/199/ --input data/cityscapes/testA/* --input_style A --model_net cyclegan --net_G resnet_6block --g_base_dims 32
-python infer.py --init_model output/checkpoints/15/ --input data/cityscapes/test/B/100.jpg --model_net Pix2pix --net_G unet_256
+python infer.py --init_model output/checkpoints/199/ --input data/cityscapes/testB/* --model_net Pix2pix --net_G unet_256
import os


def make_pair_data(fileA, file):
    # Build a paired list for Pix2pix: each output line is
    # "<path_A>\t<path_B>", with the B-side path derived by swapping
    # the "A" marker in the A-side path.
    f = open(fileA, 'r')
    lines = f.readlines()
    f.close()
    w = open(file, 'w')
    for line in lines:
        path_A = line[:-1]  # strip the trailing newline
        print(path_A)
        path_B = path_A.replace("A", "B")
        print(path_B)
        pair = path_A + '\t' + path_B + '\n'
        w.write(pair)
    w.close()


if __name__ == "__main__":
    trainA_file = "./data/cityscapes/trainA.txt"
    train_file = "./data/cityscapes/pix2pix_train_list"
    make_pair_data(trainA_file, train_file)

    testA_file = "./data/cityscapes/testA.txt"
    test_file = "./data/cityscapes/pix2pix_test_list"
    make_pair_data(testA_file, test_file)
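Given a `trainA.txt` that lists the A-side image paths one per line, the script prints each pair as it goes and writes tab-separated `A<TAB>B` lines; for example (paths assumed):

    trainA/1.jpg	trainB/1.jpg
    trainA/2.jpg	trainB/2.jpg

Note that `replace("A", "B")` swaps every capital "A" in the path, which is safe for the cityscapes trainA/trainB layout but worth checking for datasets with other naming.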
python train.py --model_net CycleGAN --dataset cityscapes --batch_size 1 --net_G resnet_9block --g_base_dims 32 --net_D basic --norm_type batch_norm --epoch 200 --load_size 286 --crop_size 256 --crop_type Random > log_out 2>log_err
-python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list10 --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err
+python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err
@@ -35,6 +35,7 @@ class GTrainer():
         with fluid.program_guard(self.program):
             model = CGAN_model()
             self.fake = model.network_G(input, conditions, name="G")
+            self.fake.persistable = True
             self.infer_program = self.program.clone()
             d_fake = model.network_D(self.fake, conditions, name="D")
             fake_labels = fluid.layers.fill_constant_batch_size_like(
@@ -42,6 +43,7 @@ class GTrainer():
             self.g_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     x=d_fake, label=fake_labels))
+            self.g_loss.persistable = True
             vars = []
             for var in self.program.list_vars():
@@ -62,7 +64,7 @@ class DTrainer():
             self.d_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     x=d_logit, label=labels))
+            self.d_loss.persistable = True
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and (var.name.startswith("D")):
@@ -112,7 +114,7 @@ class CGAN(object):
             ### memory optim
             build_strategy = fluid.BuildStrategy()
             build_strategy.enable_inplace = True
-            build_strategy.memory_optimize = False
+            build_strategy.memory_optimize = True
             g_trainer_program = fluid.CompiledProgram(
                 g_trainer.program).with_data_parallel(
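The scattered `persistable = True` additions and the `memory_optimize = True` flips in this commit are two halves of one fix: with Paddle's graph-level memory optimization on, buffers of non-persistable intermediate variables may be reused between ops, so any variable the training loop fetches (losses, generated images) must be marked persistable to stay readable. A condensed sketch of the pattern applied across the CGAN, CycleGAN, and DCGAN trainers (variable names illustrative):

    import paddle.fluid as fluid   # PaddlePaddle 1.x API, as used in this repo

    # inside a trainer's program_guard block:
    #     self.g_loss = fluid.layers.reduce_mean(...)
    #     self.g_loss.persistable = True    # keep the fetched loss alive

    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_inplace = True    # reuse buffers in place where safe
    build_strategy.memory_optimize = True   # cross-op buffer reuse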
......
@@ -47,16 +47,20 @@ class GTrainer():
                 fluid.layers.elementwise_sub(
                     x=input_B, y=self.cyc_B))
             self.cyc_A_loss = fluid.layers.reduce_mean(diff_A) * lambda_A
+            self.cyc_A_loss.persistable = True
             self.cyc_B_loss = fluid.layers.reduce_mean(diff_B) * lambda_B
+            self.cyc_B_loss.persistable = True
             self.cyc_loss = self.cyc_A_loss + self.cyc_B_loss
             # GAN Loss D_A(G_A(A))
             self.fake_rec_A = model.network_D(self.fake_B, name="DA", cfg=cfg)
             self.G_A = fluid.layers.reduce_mean(
                 fluid.layers.square(self.fake_rec_A - 1))
+            self.G_A.persistable = True
             # GAN Loss D_B(G_B(B))
             self.fake_rec_B = model.network_D(self.fake_A, name="DB", cfg=cfg)
             self.G_B = fluid.layers.reduce_mean(
                 fluid.layers.square(self.fake_rec_B - 1))
+            self.G_B.persistable = True
             self.G = self.G_A + self.G_B
             # Identity Loss G_A
             self.idt_A = model.network_G(input_B, name="GA", cfg=cfg)
@@ -64,12 +68,14 @@ class GTrainer():
                 fluid.layers.abs(
                     fluid.layers.elementwise_sub(
                         x=input_B, y=self.idt_A))) * lambda_B * lambda_identity
+            self.idt_loss_A.persistable = True
             # Identity Loss G_B
             self.idt_B = model.network_G(input_A, name="GB", cfg=cfg)
             self.idt_loss_B = fluid.layers.reduce_mean(
                 fluid.layers.abs(
                     fluid.layers.elementwise_sub(
                         x=input_A, y=self.idt_B))) * lambda_A * lambda_identity
+            self.idt_loss_B.persistable = True
             self.idt_loss = fluid.layers.elementwise_add(self.idt_loss_A,
                                                          self.idt_loss_B)
@@ -107,8 +113,8 @@ class DATrainer():
             self.d_loss_A = (fluid.layers.square(self.fake_pool_rec_B) +
                              fluid.layers.square(self.rec_B - 1)) / 2.0
             self.d_loss_A = fluid.layers.reduce_mean(self.d_loss_A)
-            optimizer = fluid.optimizer.Adam(learning_rate=0.0002, beta1=0.5)
+            self.d_loss_A.persistable = True
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and var.name.startswith("DA"):
@@ -142,7 +148,7 @@ class DBTrainer():
             self.d_loss_B = (fluid.layers.square(self.fake_pool_rec_A) +
                              fluid.layers.square(self.rec_A - 1)) / 2.0
             self.d_loss_B = fluid.layers.reduce_mean(self.d_loss_B)
-            optimizer = fluid.optimizer.Adam(learning_rate=0.0002, beta1=0.5)
+            self.d_loss_B.persistable = True
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and var.name.startswith("DB"):
@@ -230,8 +236,8 @@ class CycleGAN(object):
             ### memory optim
             build_strategy = fluid.BuildStrategy()
-            build_strategy.enable_inplace = False
-            build_strategy.memory_optimize = False
+            build_strategy.enable_inplace = True
+            build_strategy.memory_optimize = True
             gen_trainer_program = fluid.CompiledProgram(
                 gen_trainer.program).with_data_parallel(
......
@@ -35,6 +35,7 @@ class GTrainer():
         with fluid.program_guard(self.program):
             model = DCGAN_model()
             self.fake = model.network_G(input, name='G')
+            self.fake.persistable = True
             self.infer_program = self.program.clone()
             d_fake = model.network_D(self.fake, name="D")
             fake_labels = fluid.layers.fill_constant_batch_size_like(
@@ -42,6 +43,7 @@ class GTrainer():
             self.g_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     x=d_fake, label=fake_labels))
+            self.g_loss.persistable = True
             vars = []
             for var in self.program.list_vars():
@@ -61,6 +63,7 @@ class DTrainer():
             self.d_loss = fluid.layers.reduce_mean(
                 fluid.layers.sigmoid_cross_entropy_with_logits(
                     x=d_logit, label=labels))
+            self.d_loss.persistable = True
             vars = []
             for var in self.program.list_vars():
                 if fluid.io.is_parameter(var) and (var.name.startswith("D")):
@@ -78,7 +81,7 @@ class DCGAN(object):
         return parser
-    def __init__(self, cfg, train_reader):
+    def __init__(self, cfg=None, train_reader=None):
         self.cfg = cfg
         self.train_reader = train_reader
@@ -107,7 +110,7 @@ class DCGAN(object):
             ### memory optim
             build_strategy = fluid.BuildStrategy()
             build_strategy.enable_inplace = True
-            build_strategy.memory_optimize = False
+            build_strategy.memory_optimize = True
             g_trainer_program = fluid.CompiledProgram(
                 g_trainer.program).with_data_parallel(
......
@@ -68,16 +68,15 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
 def base_parse_args(parser):
     add_arg = functools.partial(add_arguments, argparser=parser)
     # yapf: disable
-    add_arg('model_net', str, "cgan", "The model used.")
+    add_arg('model_net', str, "CGAN", "The model used.")
     add_arg('dataset', str, "mnist", "The dataset used.")
     add_arg('data_dir', str, "./data", "The dataset root directory")
-    add_arg('data_list', str, "data/cityscapes/pix2pix_train_list", "The data list file name")
-    add_arg('train_list', str, "data/cityscapes/pix2pix_train_list", "The train list file name")
-    add_arg('test_list', str, "data/cityscapes/pix2pix_test_list10", "The test list file name")
+    add_arg('train_list', str, None, "The train list file name")
+    add_arg('test_list', str, None, "The test list file name")
     add_arg('batch_size', int, 1, "Minibatch size.")
     add_arg('epoch', int, 200, "The number of epochs to be trained.")
-    add_arg('g_base_dims', int, 64, "Base channels in CycleGAN generator")
-    add_arg('d_base_dims', int, 64, "Base channels in CycleGAN discriminator")
+    add_arg('g_base_dims', int, 64, "Base channels in generator")
+    add_arg('d_base_dims', int, 64, "Base channels in discriminator")
     add_arg('load_size', int, 286, "the image size when loading the image")
     add_arg('crop_type', str, 'Centor',
             "the crop type, choose = ['Centor', 'Random']")
@@ -96,7 +95,7 @@ def base_parse_args(parser):
     add_arg('gan_mode', str, "vanilla", "The GAN loss mode used.")
     add_arg('norm_type', str, "batch_norm", "Which normalization to use")
     add_arg('learning_rate', float, 0.0002, "the initial learning rate")
-    add_arg('lambda_L1', float, 100.0, "the initialize learning rate")
+    add_arg('lambda_L1', float, 100.0, "the lambda weight for the L1 loss")
     add_arg('num_generator_time', int, 1,
             "the generator run times in training each epoch")
     add_arg('print_freq', int, 10, "the frequency of print loss")
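With `data_list` gone and both list arguments defaulting to None, the reader falls back to `<data_dir>/<dataset>/train.txt` and `test.txt` unless lists are passed explicitly (see the data_reader hunks above), so a Pix2pix run supplies both, matching the train command earlier in this commit:

    python train.py --model_net Pix2pix --dataset cityscapes \
        --train_list data/cityscapes/pix2pix_train_list \
        --test_list data/cityscapes/pix2pix_test_list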
......