Commit ad1a917f authored by lvmengsi, committed by ruri

fix spade typo (#3365)

Parent c82ab203
@@ -23,6 +23,8 @@ import time
 import network.vgg as vgg
 import pickle as pkl
 import numpy as np
+
+
 class GTrainer():
     def __init__(self, input_label, input_img, input_ins, cfg, step_per_epoch):
         self.cfg = cfg
@@ -43,8 +45,10 @@ class GTrainer():
                 self.pred_fake = []
                 self.pred_real = []
                 for p in pred:
-                    self.pred_fake.append([tensor[:tensor.shape[0] // 2] for tensor in p])
-                    self.pred_real.append([tensor[tensor.shape[0] // 2:] for tensor in p])
+                    self.pred_fake.append(
+                        [tensor[:tensor.shape[0] // 2] for tensor in p])
+                    self.pred_real.append(
+                        [tensor[tensor.shape[0] // 2:] for tensor in p])
             else:
                 self.pred_fake = pred[:pred.shape[0] // 2]
                 self.pred_real = pred[pred.shape[0] // 2:]
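A note on the split above: the discriminator is run once on a batch that stacks the fake and the real images along dimension 0, so every prediction tensor is cut in half to recover the fake-half and the real-half outputs. A minimal NumPy sketch of the same slicing, assuming a multi-scale discriminator that returns a list of per-scale feature lists (shapes and names are illustrative only):

```python
import numpy as np

# Illustrative stand-in for a multi-scale discriminator output: one entry per
# scale, each a list of intermediate feature maps whose batch axis stacks
# [fake; real] samples (4 fake + 4 real here).
pred = [[np.random.rand(8, 4, 8, 8), np.random.rand(8, 1, 4, 4)]
        for _ in range(2)]

pred_fake = [[t[:t.shape[0] // 2] for t in p] for p in pred]  # first half: fake
pred_real = [[t[t.shape[0] // 2:] for t in p] for p in pred]  # second half: real

assert pred_fake[0][0].shape[0] == 4 and pred_real[0][0].shape[0] == 4
```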
@@ -67,20 +71,26 @@ class GTrainer():
             for i in range(num_D):
                 num_intermediate_outputs = len(self.pred_fake[i]) - 1
                 for j in range(num_intermediate_outputs):
-                    self.gan_feat_loss = fluid.layers.reduce_mean(fluid.layers.abs(fluid.layers.elementwise_sub(
-                        x=self.pred_fake[i][j], y=self.pred_real[i][j]))) * cfg.lambda_feat / num_D
+                    self.gan_feat_loss = fluid.layers.reduce_mean(
+                        fluid.layers.abs(
+                            fluid.layers.elementwise_sub(
+                                x=self.pred_fake[i][j], y=self.pred_real[i][
+                                    j]))) * cfg.lambda_feat / num_D
             self.gan_feat_loss.persistable = True
             ########VGG Feat loss
-            weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
+            weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
             self.vgg = vgg.VGG19()
             fake_vgg = self.vgg.net(self.fake_B)
             real_vgg = self.vgg.net(input_img)
             self.vgg_loss = 0.0
             for i in range(len(fake_vgg)):
-                self.vgg_loss += weights[i] * fluid.layers.reduce_mean(fluid.layers.abs(fluid.layers.elementwise_sub(
-                    x=fake_vgg[i], y=real_vgg[i])))
+                self.vgg_loss += weights[i] * fluid.layers.reduce_mean(
+                    fluid.layers.abs(
+                        fluid.layers.elementwise_sub(
+                            x=fake_vgg[i], y=real_vgg[i])))
             self.vgg_loss.persistable = True
-            self.g_loss = (self.gan_loss + self.gan_feat_loss + self.vgg_loss)/3
+            self.g_loss = (
+                self.gan_loss + self.gan_feat_loss + self.vgg_loss) / 3
             lr = cfg.learning_rate
             vars = []
             for var in self.program.list_vars():
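For reference, the generator objective assembled in this hunk averages three terms: the adversarial term `self.gan_loss` (computed in lines not shown here), a GAN feature-matching loss (L1 distance between intermediate discriminator features of fake and real, scaled by `cfg.lambda_feat / num_D`), and a VGG19 perceptual loss with layer weights `[1/32, 1/16, 1/8, 1/4, 1]`. The NumPy sketch below mirrors that arithmetic on toy arrays; it accumulates the feature-matching term over scales and layers (the usual pix2pixHD-style formulation), whereas the fluid code above assigns it inside the loop. All shapes and values are illustrative.

```python
import numpy as np

def l1_mean(a, b):
    # reduce_mean(abs(a - b)), the same expression the fluid code builds.
    return float(np.mean(np.abs(a - b)))

num_D, lambda_feat = 2, 10.0          # illustrative values, not cfg defaults
pred_fake = [[np.random.rand(4, 16, 16) for _ in range(3)] for _ in range(num_D)]
pred_real = [[np.random.rand(4, 16, 16) for _ in range(3)] for _ in range(num_D)]

gan_feat_loss = 0.0
for i in range(num_D):
    for j in range(len(pred_fake[i]) - 1):      # skip the final prediction map
        gan_feat_loss += l1_mean(pred_fake[i][j],
                                 pred_real[i][j]) * lambda_feat / num_D

weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
fake_vgg = [np.random.rand(4, 8, 8) for _ in weights]   # stand-in VGG19 features
real_vgg = [np.random.rand(4, 8, 8) for _ in weights]
vgg_loss = sum(w * l1_mean(f, r) for w, f, r in zip(weights, fake_vgg, real_vgg))

gan_loss = 0.5                                  # placeholder adversarial term
g_loss = (gan_loss + gan_feat_loss + vgg_loss) / 3
```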
@@ -109,7 +119,8 @@ class GTrainer():
 class DTrainer():
-    def __init__(self, input_label, input_img, input_ins, fake_B, cfg, step_per_epoch):
+    def __init__(self, input_label, input_img, input_ins, fake_B, cfg,
+                 step_per_epoch):
         self.program = fluid.default_main_program().clone()
         lr = cfg.learning_rate
         with fluid.program_guard(self.program):
@@ -125,8 +136,10 @@ class DTrainer():
                 self.pred_fake = []
                 self.pred_real = []
                 for p in pred:
-                    self.pred_fake.append([tensor[:tensor.shape[0] // 2] for tensor in p])
-                    self.pred_real.append([tensor[tensor.shape[0] // 2:] for tensor in p])
+                    self.pred_fake.append(
+                        [tensor[:tensor.shape[0] // 2] for tensor in p])
+                    self.pred_real.append(
+                        [tensor[tensor.shape[0] // 2:] for tensor in p])
             else:
                 self.pred_fake = pred[:pred.shape[0] // 2]
                 self.pred_real = pred[pred.shape[0] // 2:]
@@ -134,20 +147,28 @@ class DTrainer():
             #####gan loss
             self.gan_loss_fake = 0
             for pred_i in self.pred_fake:
-                zeros = fluid.layers.fill_constant_batch_size_like(input=pred_i[-1],shape=pred_i[-1].shape,value=0,dtype='float32')
+                zeros = fluid.layers.fill_constant_batch_size_like(
+                    input=pred_i[-1],
+                    shape=pred_i[-1].shape,
+                    value=0,
+                    dtype='float32')
                 if isinstance(pred_i, list):
                     pred_i = pred_i[-1]
-                minval = fluid.layers.elementwise_min(-1 * pred_i-1, zeros)
+                minval = fluid.layers.elementwise_min(-1 * pred_i - 1, zeros)
                 loss_i = -1 * fluid.layers.reduce_mean(minval)
                 self.gan_loss_fake += loss_i
             self.gan_loss_fake /= len(self.pred_fake)

             self.gan_loss_real = 0
             for pred_i in self.pred_real:
-                zeros = fluid.layers.fill_constant_batch_size_like(input=pred_i[-1],shape=pred_i[-1].shape,value=0,dtype='float32')
+                zeros = fluid.layers.fill_constant_batch_size_like(
+                    input=pred_i[-1],
+                    shape=pred_i[-1].shape,
+                    value=0,
+                    dtype='float32')
                 if isinstance(pred_i, list):
                     pred_i = pred_i[-1]
-                minval = fluid.layers.elementwise_min(pred_i-1, zeros)
+                minval = fluid.layers.elementwise_min(pred_i - 1, zeros)
                 loss_i = -1 * fluid.layers.reduce_mean(minval)
                 self.gan_loss_real += loss_i
             self.gan_loss_real /= len(self.pred_real)
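The two loops above implement the discriminator side of the hinge GAN loss: fake scores are pushed below -1 and real scores above +1, i.e. `loss_fake = -mean(min(-D(fake) - 1, 0))` and `loss_real = -mean(min(D(real) - 1, 0))`, each averaged over the discriminator scales. A small NumPy sketch of the same two terms (the scores are made up):

```python
import numpy as np

def d_hinge_terms(fake_scores, real_scores):
    # Mirrors the fluid ops above on raw (pre-activation) discriminator scores.
    loss_fake = -np.mean(np.minimum(-fake_scores - 1.0, 0.0))
    loss_real = -np.mean(np.minimum(real_scores - 1.0, 0.0))
    return loss_fake, loss_real

fake_scores = np.array([-1.5, 0.3, -0.2])   # already past the margin, or not
real_scores = np.array([1.2, 0.4, -0.8])
print(d_hinge_terms(fake_scores, real_scores))
```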
@@ -188,8 +209,7 @@ class SPADE(object):
             '--vgg19_pretrain',
             type=str,
             default="./VGG19_pretrained",
-            help="VGG19 pretrained model for vgg loss"
-        )
+            help="VGG19 pretrained model for vgg loss")
         parser.add_argument(
             '--crop_width',
             type=int,
@@ -216,10 +236,7 @@ class SPADE(object):
             default=4,
             help="num of discriminator layers for SPADE")
         parser.add_argument(
-            '--label_nc',
-            type=int,
-            default=36,
-            help="label numbers of SPADE")
+            '--label_nc', type=int, default=36, help="label numbers of SPADE")
         parser.add_argument(
             '--ngf',
             type=int,
@@ -245,7 +262,11 @@ class SPADE(object):
             type=float,
             default=10,
             help="weight term of vgg loss")
-        parser.add_argument('--no_instance', type=bool, default=False, help="Whether to use instance label.")
+        parser.add_argument(
+            '--no_instance',
+            type=bool,
+            default=False,
+            help="Whether to use instance label.")
         return parser
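One caveat about the `--no_instance` option shown above: argparse's `type=bool` converts any non-empty string, including "False", to True, so `--no_instance False` on the command line would still evaluate as truthy. A hedged workaround sketch using a hypothetical `str2bool` helper (not part of this codebase):

```python
import argparse

def str2bool(v):
    # Hypothetical helper: map common string spellings to a real boolean.
    return str(v).lower() in ("1", "true", "yes", "y")

parser = argparse.ArgumentParser()
parser.add_argument(
    '--no_instance',
    type=str2bool,
    default=False,
    help="Whether to use instance label.")
print(parser.parse_args(['--no_instance', 'False']).no_instance)  # -> False
```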
@@ -261,7 +282,9 @@ class SPADE(object):
     def build_model(self):
         data_shape = [-1, 3, self.cfg.crop_height, self.cfg.crop_width]
-        label_shape = [-1, self.cfg.label_nc, self.cfg.crop_height, self.cfg.crop_width]
+        label_shape = [
+            -1, self.cfg.label_nc, self.cfg.crop_height, self.cfg.crop_width
+        ]
         edge_shape = [-1, 1, self.cfg.crop_height, self.cfg.crop_width]
         input_A = fluid.layers.data(
@@ -273,7 +296,8 @@ class SPADE(object):
         input_fake = fluid.layers.data(
             name='input_fake', shape=data_shape, dtype='float32')
-        gen_trainer = GTrainer(input_A, input_B, input_C, self.cfg, self.batch_num)
+        gen_trainer = GTrainer(input_A, input_B, input_C, self.cfg,
+                               self.batch_num)
         dis_trainer = DTrainer(input_A, input_B, input_C, input_fake, self.cfg,
                                self.batch_num)
         py_reader = fluid.io.PyReader(
@@ -290,7 +314,8 @@ class SPADE(object):
         place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
-        gen_trainer.vgg.load_vars(exe, gen_trainer.program, self.cfg.vgg19_pretrain)
+        gen_trainer.vgg.load_vars(exe, gen_trainer.program,
+                                  self.cfg.vgg19_pretrain)
         if self.cfg.init_model:
             utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G")
@@ -315,7 +340,8 @@ class SPADE(object):
         for epoch_id in range(self.cfg.epoch):
             batch_id = 0
             for tensor in py_reader():
-                data_A, data_B, data_C = tensor[0]['input_A'], tensor[0]['input_B'], tensor[0]['input_C']
+                data_A, data_B, data_C = tensor[0]['input_A'], tensor[0][
+                    'input_B'], tensor[0]['input_C']
                 tensor_A = fluid.LoDTensor()
                 tensor_B = fluid.LoDTensor()
                 tensor_C = fluid.LoDTensor()
@@ -327,18 +353,20 @@ class SPADE(object):
                 g_loss_gan, g_loss_vgg, g_loss_feat, fake_B_tmp = exe.run(
                     gen_trainer_program,
                     fetch_list=[
-                        gen_trainer.gan_loss, gen_trainer.vgg_loss, gen_trainer.gan_feat_loss,
-                        gen_trainer.fake_B
+                        gen_trainer.gan_loss, gen_trainer.vgg_loss,
+                        gen_trainer.gan_feat_loss, gen_trainer.fake_B
                     ],
-                    feed={"input_label": tensor_A,
-                          "input_img": tensor_B,
-                          "input_ins": tensor_C})
+                    feed={
+                        "input_label": tensor_A,
+                        "input_img": tensor_B,
+                        "input_ins": tensor_C
+                    })

                 # optimize the discriminator network
-                d_loss_real, d_loss_fake = exe.run(dis_trainer_program,
+                d_loss_real, d_loss_fake = exe.run(
+                    dis_trainer_program,
                     fetch_list=[
-                        dis_trainer.gan_loss_real,
-                        dis_trainer.gan_loss_fake
+                        dis_trainer.gan_loss_real, dis_trainer.gan_loss_fake
                     ],
                     feed={
                         "input_label": tensor_A,
@@ -355,7 +383,8 @@ class SPADE(object):
                         d_loss_real: {}; d_loss_fake: {}; \n\
                         Batch_time_cost: {:.2f}"
                     .format(epoch_id, batch_id, g_loss_gan[0], g_loss_vgg[
-                        0], g_loss_feat[0], d_loss_real[0], d_loss_fake[0], batch_time))
+                        0], g_loss_feat[0], d_loss_real[0], d_loss_fake[
+                            0], batch_time))
                 sys.stdout.flush()
                 batch_id += 1
......
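Taken together, the loop in this file alternates the two cloned programs every batch: the generator program is run first to optimize the generator and fetch the generated image, and the discriminator program is then run on the same inputs plus that freshly generated image, which is fed through the `input_fake` placeholder declared in build_model. A schematic of the pattern, independent of the fluid API (all names are placeholders):

```python
# Schematic only; stand-ins for exe.run(gen_trainer_program, ...) and
# exe.run(dis_trainer_program, ...) in the loop above.
def train_one_batch(run_generator_step, run_discriminator_step, batch):
    g_losses, fake_img = run_generator_step(batch)       # update G, get fake image
    d_losses = run_discriminator_step(batch, fake_img)   # update D on real + fake
    return g_losses, d_losses
```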