Commit ad1a917f authored by lvmengsi, committed by ruri

fix spade typo (#3365)

Parent c82ab203
@@ -23,6 +23,8 @@ import time
import network.vgg as vgg
import pickle as pkl
import numpy as np
class GTrainer():
def __init__(self, input_label, input_img, input_ins, cfg, step_per_epoch):
self.cfg = cfg
@@ -43,8 +45,10 @@ class GTrainer():
self.pred_fake = []
self.pred_real = []
for p in pred:
self.pred_fake.append([tensor[:tensor.shape[0] // 2] for tensor in p])
self.pred_real.append([tensor[tensor.shape[0] // 2:] for tensor in p])
self.pred_fake.append(
[tensor[:tensor.shape[0] // 2] for tensor in p])
self.pred_real.append(
[tensor[tensor.shape[0] // 2:] for tensor in p])
else:
self.pred_fake = pred[:pred.shape[0] // 2]
self.pred_real = pred[pred.shape[0] // 2:]
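
A brief aside on the slicing in this hunk (a hedged numpy sketch, not the repo's code): the discriminator is assumed to see fake and real images concatenated along the batch axis, so every prediction tensor carries the fake half first and the real half second.

import numpy as np

# toy multi-scale discriminator output whose batch axis stacks [fake, real]
pred = [np.random.rand(4, 64, 32, 32), np.random.rand(4, 1, 8, 8)]

pred_fake = [t[:t.shape[0] // 2] for t in pred]   # first half  -> fake preds
pred_real = [t[t.shape[0] // 2:] for t in pred]   # second half -> real preds

assert pred_fake[0].shape[0] == pred_real[0].shape[0] == 2
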
@@ -67,20 +71,26 @@ class GTrainer():
for i in range(num_D):
num_intermediate_outputs = len(self.pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
self.gan_feat_loss = fluid.layers.reduce_mean(fluid.layers.abs(fluid.layers.elementwise_sub(
x=self.pred_fake[i][j], y=self.pred_real[i][j]))) * cfg.lambda_feat / num_D
self.gan_feat_loss = fluid.layers.reduce_mean(
fluid.layers.abs(
fluid.layers.elementwise_sub(
x=self.pred_fake[i][j], y=self.pred_real[i][
j]))) * cfg.lambda_feat / num_D
self.gan_feat_loss.persistable = True
########VGG Feat loss
weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
self.vgg = vgg.VGG19()
fake_vgg = self.vgg.net(self.fake_B)
real_vgg = self.vgg.net(input_img)
self.vgg_loss = 0.0
for i in range(len(fake_vgg)):
self.vgg_loss += weights[i] * fluid.layers.reduce_mean(fluid.layers.abs(fluid.layers.elementwise_sub(
x=fake_vgg[i], y=real_vgg[i])))
self.vgg_loss += weights[i] * fluid.layers.reduce_mean(
fluid.layers.abs(
fluid.layers.elementwise_sub(
x=fake_vgg[i], y=real_vgg[i])))
self.vgg_loss.persistable = True
self.g_loss = (self.gan_loss + self.gan_feat_loss + self.vgg_loss)/3
self.g_loss = (
self.gan_loss + self.gan_feat_loss + self.vgg_loss) / 3
lr = cfg.learning_rate
vars = []
for var in self.program.list_vars():
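
For orientation, a minimal numpy sketch of the two loss terms assembled above (shapes and config values are placeholders, not the repo's data): a per-pair L1 feature-matching term scaled by lambda_feat / num_D, and a weighted sum of L1 gaps between VGG19 feature maps using the 1/32 .. 1 weights.

import numpy as np

def l1_mean(a, b):
    # mirrors reduce_mean(abs(elementwise_sub(a, b))) in the hunk above
    return float(np.mean(np.abs(a - b)))

lambda_feat, num_D = 10.0, 2                      # hypothetical config values
fake_feat = np.random.rand(2, 64, 16, 16).astype('float32')
real_feat = np.random.rand(2, 64, 16, 16).astype('float32')
# one feature-matching term for a single (discriminator, layer) pair
gan_feat_term = l1_mean(fake_feat, real_feat) * lambda_feat / num_D

weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
fake_vgg = [np.random.rand(2, 8, 8).astype('float32') for _ in weights]
real_vgg = [np.random.rand(2, 8, 8).astype('float32') for _ in weights]
vgg_loss = sum(w * l1_mean(f, r) for w, f, r in zip(weights, fake_vgg, real_vgg))
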
@@ -109,7 +119,8 @@ class GTrainer():
class DTrainer():
def __init__(self, input_label, input_img, input_ins, fake_B, cfg, step_per_epoch):
def __init__(self, input_label, input_img, input_ins, fake_B, cfg,
step_per_epoch):
self.program = fluid.default_main_program().clone()
lr = cfg.learning_rate
with fluid.program_guard(self.program):
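
Context for the pattern above, as a hedged sketch (names below are illustrative, not this repo's code): each trainer clones the default main program and declares its layers under program_guard, which keeps the generator and discriminator graphs in separate programs that one executor can run independently.

import paddle.fluid as fluid

d_program = fluid.default_main_program().clone()
with fluid.program_guard(d_program):
    # layers declared here are added to d_program only, so the generator's
    # graph (built the same way in GTrainer) stays separate
    x = fluid.layers.data(name='x', shape=[3, 8, 8], dtype='float32')
    loss = fluid.layers.reduce_mean(x)
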
@@ -125,8 +136,10 @@ class DTrainer():
self.pred_fake = []
self.pred_real = []
for p in pred:
self.pred_fake.append([tensor[:tensor.shape[0] // 2] for tensor in p])
self.pred_real.append([tensor[tensor.shape[0] // 2:] for tensor in p])
self.pred_fake.append(
[tensor[:tensor.shape[0] // 2] for tensor in p])
self.pred_real.append(
[tensor[tensor.shape[0] // 2:] for tensor in p])
else:
self.pred_fake = pred[:pred.shape[0] // 2]
self.pred_real = pred[pred.shape[0] // 2:]
@@ -134,20 +147,28 @@ class DTrainer():
#####gan loss
self.gan_loss_fake = 0
for pred_i in self.pred_fake:
zeros = fluid.layers.fill_constant_batch_size_like(input=pred_i[-1],shape=pred_i[-1].shape,value=0,dtype='float32')
zeros = fluid.layers.fill_constant_batch_size_like(
input=pred_i[-1],
shape=pred_i[-1].shape,
value=0,
dtype='float32')
if isinstance(pred_i, list):
pred_i = pred_i[-1]
minval = fluid.layers.elementwise_min(-1 * pred_i-1, zeros)
minval = fluid.layers.elementwise_min(-1 * pred_i - 1, zeros)
loss_i = -1 * fluid.layers.reduce_mean(minval)
self.gan_loss_fake += loss_i
self.gan_loss_fake /= len(self.pred_fake)
self.gan_loss_real = 0
for pred_i in self.pred_real:
zeros = fluid.layers.fill_constant_batch_size_like(input=pred_i[-1],shape=pred_i[-1].shape,value=0,dtype='float32')
zeros = fluid.layers.fill_constant_batch_size_like(
input=pred_i[-1],
shape=pred_i[-1].shape,
value=0,
dtype='float32')
if isinstance(pred_i, list):
pred_i = pred_i[-1]
minval = fluid.layers.elementwise_min(pred_i-1, zeros)
minval = fluid.layers.elementwise_min(pred_i - 1, zeros)
loss_i = -1 * fluid.layers.reduce_mean(minval)
self.gan_loss_real += loss_i
self.gan_loss_real /= len(self.pred_real)
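
A small numpy sketch of the hinge objective built in this hunk: the fake branch is -mean(min(-D(fake) - 1, 0)), which vanishes once D(fake) <= -1, and the real branch is -mean(min(D(real) - 1, 0)), which vanishes once D(real) >= 1 (toy logits below).

import numpy as np

def d_hinge_fake(pred_fake):
    # -mean(min(-D(fake) - 1, 0)): zero once D(fake) <= -1
    return -np.mean(np.minimum(-1.0 * pred_fake - 1.0, 0.0))

def d_hinge_real(pred_real):
    # -mean(min(D(real) - 1, 0)): zero once D(real) >= 1
    return -np.mean(np.minimum(pred_real - 1.0, 0.0))

print(d_hinge_fake(np.array([-2.0, 0.5])))   # only the 0.5 logit is penalized
print(d_hinge_real(np.array([2.0, -0.5])))   # only the -0.5 logit is penalized
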
@@ -188,8 +209,7 @@ class SPADE(object):
'--vgg19_pretrain',
type=str,
default="./VGG19_pretrained",
help="VGG19 pretrained model for vgg loss"
)
help="VGG19 pretrained model for vgg loss")
parser.add_argument(
'--crop_width',
type=int,
@@ -216,10 +236,7 @@ class SPADE(object):
default=4,
help="num of discriminator layers for SPADE")
parser.add_argument(
'--label_nc',
type=int,
default=36,
help="label numbers of SPADE")
'--label_nc', type=int, default=36, help="label numbers of SPADE")
parser.add_argument(
'--ngf',
type=int,
@@ -245,7 +262,11 @@ class SPADE(object):
type=float,
default=10,
help="weight term of vgg loss")
parser.add_argument('--no_instance', type=bool, default=False, help="Whether to use instance label.")
parser.add_argument(
'--no_instance',
type=bool,
default=False,
help="Whether to use instance label.")
return parser
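
One caution unrelated to this commit: argparse's type=bool treats any non-empty string (including "False") as True, so a boolean flag like --no_instance is more commonly declared with a store-true action. A minimal sketch:

import argparse

parser = argparse.ArgumentParser()
# store_true makes the flag default to False and flip to True when passed
parser.add_argument('--no_instance', action='store_true',
                    help="Whether to use instance label.")
print(parser.parse_args([]).no_instance)                 # False
print(parser.parse_args(['--no_instance']).no_instance)  # True
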
@@ -261,7 +282,9 @@ class SPADE(object):
def build_model(self):
data_shape = [-1, 3, self.cfg.crop_height, self.cfg.crop_width]
label_shape = [-1, self.cfg.label_nc, self.cfg.crop_height, self.cfg.crop_width]
label_shape = [
-1, self.cfg.label_nc, self.cfg.crop_height, self.cfg.crop_width
]
edge_shape = [-1, 1, self.cfg.crop_height, self.cfg.crop_width]
input_A = fluid.layers.data(
@@ -273,7 +296,8 @@ class SPADE(object):
input_fake = fluid.layers.data(
name='input_fake', shape=data_shape, dtype='float32')
gen_trainer = GTrainer(input_A, input_B, input_C, self.cfg, self.batch_num)
gen_trainer = GTrainer(input_A, input_B, input_C, self.cfg,
self.batch_num)
dis_trainer = DTrainer(input_A, input_B, input_C, input_fake, self.cfg,
self.batch_num)
py_reader = fluid.io.PyReader(
@@ -281,7 +305,7 @@ class SPADE(object):
capacity=4, ## batch_size * 4
iterable=True,
use_double_buffer=True)
py_reader.decorate_batch_generator(
py_reader.decorate_batch_generator(
self.train_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
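
A hedged sketch of a batch generator in the shape decorate_batch_generator expects here (names, sizes, and ordering are assumptions, not this repo's reader): it yields one batch of numpy arrays per step, matching the three input layers.

import numpy as np

def toy_batch_generator(batch_size=2, crop=256, label_nc=36):
    # yields [label, image, instance-edge] numpy batches, one list per step
    while True:
        label = np.zeros((batch_size, label_nc, crop, crop), dtype='float32')
        img = np.zeros((batch_size, 3, crop, crop), dtype='float32')
        ins = np.zeros((batch_size, 1, crop, crop), dtype='float32')
        yield [label, img, ins]
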
@@ -290,8 +314,9 @@ class SPADE(object):
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
gen_trainer.vgg.load_vars(exe, gen_trainer.program, self.cfg.vgg19_pretrain)
gen_trainer.vgg.load_vars(exe, gen_trainer.program,
self.cfg.vgg19_pretrain)
if self.cfg.init_model:
utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G")
utility.init_checkpoints(self.cfg, exe, dis_trainer, "net_D")
@@ -315,7 +340,8 @@ class SPADE(object):
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for tensor in py_reader():
data_A, data_B, data_C = tensor[0]['input_A'], tensor[0]['input_B'], tensor[0]['input_C']
data_A, data_B, data_C = tensor[0]['input_A'], tensor[0][
'input_B'], tensor[0]['input_C']
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_C = fluid.LoDTensor()
@@ -327,25 +353,27 @@ class SPADE(object):
g_loss_gan, g_loss_vgg, g_loss_feat, fake_B_tmp = exe.run(
gen_trainer_program,
fetch_list=[
gen_trainer.gan_loss, gen_trainer.vgg_loss, gen_trainer.gan_feat_loss,
gen_trainer.fake_B
gen_trainer.gan_loss, gen_trainer.vgg_loss,
gen_trainer.gan_feat_loss, gen_trainer.fake_B
],
feed={"input_label": tensor_A,
"input_img": tensor_B,
"input_ins": tensor_C})
feed={
"input_label": tensor_A,
"input_img": tensor_B,
"input_ins": tensor_C
})
# optimize the discriminator network
d_loss_real, d_loss_fake = exe.run(dis_trainer_program,
fetch_list=[
dis_trainer.gan_loss_real,
dis_trainer.gan_loss_fake
],
feed={
"input_label": tensor_A,
"input_img": tensor_B,
"input_ins": tensor_C,
"input_fake": fake_B_tmp
})
d_loss_real, d_loss_fake = exe.run(
dis_trainer_program,
fetch_list=[
dis_trainer.gan_loss_real, dis_trainer.gan_loss_fake
],
feed={
"input_label": tensor_A,
"input_img": tensor_B,
"input_ins": tensor_C,
"input_fake": fake_B_tmp
})
batch_time = time.time() - s_time
t_time += batch_time
@@ -355,7 +383,8 @@ class SPADE(object):
d_loss_real: {}; d_loss_fake: {}; \n\
Batch_time_cost: {:.2f}"
.format(epoch_id, batch_id, g_loss_gan[0], g_loss_vgg[
0], g_loss_feat[0], d_loss_real[0], d_loss_fake[0], batch_time))
0], g_loss_feat[0], d_loss_real[0], d_loss_fake[
0], batch_time))
sys.stdout.flush()
batch_id += 1
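
The loop above copies each numpy batch into LoDTensors before feeding; a hedged sketch of that pattern (shape and place below are placeholders):

import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
data_A = np.zeros((2, 36, 256, 256), dtype='float32')  # hypothetical label batch

tensor_A = fluid.LoDTensor()
tensor_A.set(data_A, place)   # copy the numpy batch onto the target place
# exe.run(..., feed={"input_label": tensor_A, ...}) then consumes it
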