提交 796b59d0 · 编写者:hypox64

sngan

上级 48c032b1
......@@ -103,7 +103,7 @@ def define_G(N=2, n_blocks=1, gpu_id='-1'):
return netG
################################Discriminator################################
def define_D(input_nc, ndf, n_layers_D, use_sigmoid=False, num_D=1, gpu_id='-1'):
def define_D(input_nc=6, ndf=64, n_layers_D=3, use_sigmoid=False, num_D=4, gpu_id='-1'):
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, num_D)
if gpu_id != '-1' and len(gpu_id) == 1:
netD.cuda()
......@@ -188,20 +188,19 @@ class GANLoss(nn.Module):
def forward(self, dis_fake = None, dis_real = None):
    """Multi-scale GAN loss.

    dis_fake / dis_real: discriminator outputs — either a single output or a
    list of per-scale outputs (the multiscale-discriminator case). self.mode
    selects the discriminator ('D') or generator ('G') formulation and
    self.lossf is the underlying criterion (set elsewhere in this class).
    Returns the scalar loss.

    NOTE(review): this span is an unmarked diff — several adjacent duplicated
    statements below are the pre-/post-change versions of the same line; only
    one of each pair belongs in the final code. Confirm against the repo.
    """
    if isinstance(dis_fake, list):
        # Per-scale weighting: starts at 2**num_scales, halved each scale.
        weight = 2**len(dis_fake)
        if self.mode == 'D':
            loss = 0
            for i in range(len(dis_fake)):
                weight = weight/2
                # NOTE(review): old (weighted, whole-output) version ...
                loss += weight*self.lossf(dis_fake[i],dis_real[i])
                # ... apparently superseded by the unweighted [0]-indexed one.
                loss += self.lossf(dis_fake[i][0],dis_real[i][0])
        elif self.mode =='G':
            loss = 0
            weight = 2**len(dis_fake)
            for i in range(len(dis_fake)):
                weight = weight/2
                # NOTE(review): old/new pair — second line supersedes first.
                loss += weight*self.lossf(dis_fake[i])
                loss += weight*self.lossf(dis_fake[i][0])
        return loss
    else:
        if self.mode == 'D':
            # NOTE(review): old/new pair — the second return is the
            # replacement (unreachable while both lines are present).
            return self.lossf(dis_fake,dis_real)
            return self.lossf(dis_fake[0],dis_real[0])
        elif self.mode =='G':
            # NOTE(review): old/new pair, as above.
            return self.lossf(dis_fake)
            return self.lossf(dis_fake[0])
......@@ -114,7 +114,7 @@ class HingeLossG(nn.Module):
super(HingeLossG, self).__init__()
def forward(self, dis_fake):
    """Generator-side hinge loss on the discriminator's fake-sample output.

    NOTE(review): unmarked diff residue — the first assignment (plain
    -mean, the usual hinge generator loss) is overwritten by the F.relu
    variant on the next line; only one belongs in the final code. Confirm
    against the repo which form is intended.
    """
    loss_fake = -torch.mean(dis_fake)
    loss_fake = F.relu(-torch.mean(dis_fake))
    return loss_fake
class VGGLoss(nn.Module):
......@@ -128,7 +128,7 @@ class VGGLoss(nn.Module):
self.vgg = nn.DataParallel(self.vgg)
self.vgg.cuda()
self.criterion = nn.L1Loss()
self.criterion = nn.MSELoss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
def forward(self, x, y):
......
......@@ -12,10 +12,10 @@ import torch
import torch.nn as nn
import time
from util import mosaic,util,ffmpeg,filt,data,dataloader
from util import util,data,dataloader
from util import image_processing as impro
from models import BVDNet,model_util
import torch.backends.cudnn as cudnn
from skimage.metrics import structural_similarity
from tensorboardX import SummaryWriter
'''
......@@ -44,7 +44,26 @@ opt.parser.add_argument('--continue_train', action='store_true', help='')
opt.parser.add_argument('--savename',type=str,default='face', help='')
opt.parser.add_argument('--showresult_freq',type=int,default=1000, help='')
opt.parser.add_argument('--showresult_num',type=int,default=4, help='')
opt.parser.add_argument('--psnr_freq',type=int,default=100, help='')
def ImageQualityEvaluation(tensor1,tensor2,showiter,writer,tag):
    """Average PSNR and SSIM between two image batches; log both to tensorboard.

    tensor1/tensor2: batched image tensors compared pairwise by batch index
    (assumed to have the same batch size — TODO confirm at call sites).
    showiter: global step used for the tensorboard x-axis.
    tag: curve label ('train'/'eval').
    Returns (mean_psnr, mean_ssim) over the batch.
    """
    batch_len = len(tensor1)
    # Decode every batch element to a numpy RGB image pair up front.
    image_pairs = [(data.tensor2im(tensor1,rgb2bgr=False,batch_index=i),
                    data.tensor2im(tensor2,rgb2bgr=False,batch_index=i))
                   for i in range(batch_len)]
    mean_psnr = sum(impro.psnr(a,b) for a,b in image_pairs)/batch_len
    mean_ssmi = sum(structural_similarity(a,b,multichannel=True) for a,b in image_pairs)/batch_len
    writer.add_scalars('quality/psnr', {tag:mean_psnr}, showiter)
    writer.add_scalars('quality/ssmi', {tag:mean_ssmi}, showiter)
    return mean_psnr,mean_ssmi
def ShowImage(tensor1,tensor2,tensor3,showiter,max_num,writer,tag):
    """Log a side-by-side comparison grid to tensorboard.

    Builds a (max_num rows) x (3 columns) grid of the first max_num batch
    elements of tensor1 | tensor2 | tensor3 (e.g. mosaic input | prediction |
    ground truth) and writes it as an HWC image under `tag` at step showiter.
    """
    show_imgs = []
    for i in range(max_num):
        show_imgs += [ data.tensor2im(tensor1,rgb2bgr = False,batch_index=i),
                        data.tensor2im(tensor2,rgb2bgr = False,batch_index=i),
                        data.tensor2im(tensor3,rgb2bgr = False,batch_index=i)]
    # Fix: the grid height must follow the max_num parameter, not the global
    # opt.showresult_num — otherwise a caller passing a different max_num
    # would splice a mismatched grid. Identical behavior at current call
    # sites, which all pass opt.showresult_num.
    show_img = impro.splice(show_imgs, (max_num,3))
    writer.add_image(tag, show_img,showiter,dataformats='HWC')
'''
--------------------------Init--------------------------
......@@ -66,10 +85,15 @@ TBGlobalWriter = SummaryWriter(tensorboard_savedir)
if opt.gpu_id != '-1' and len(opt.gpu_id) == 1:
torch.backends.cudnn.benchmark = True
netG = BVDNet.define_G(opt.N,gpu_id=opt.gpu_id)
netD = BVDNet.define_D(gpu_id=opt.gpu_id)
optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
lossf_L2 = nn.MSELoss()
lossf_VGG = model_util.VGGLoss(opt.gpu_id)
optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
lossfun_L2 = nn.MSELoss()
lossfun_VGG = model_util.VGGLoss(opt.gpu_id)
lossfun_GAND = BVDNet.GANLoss('D')
lossfun_GANG = BVDNet.GANLoss('G')
'''
--------------------------Init DataLoader--------------------------
......@@ -101,35 +125,46 @@ for train_iter in range(Videodataloader_train.n_iter):
previous_frame = data.to_tensor(previous_predframe_tmp, opt.gpu_id)
else:
previous_frame = data.to_tensor(previous_frame, opt.gpu_id)
optimizer_G.zero_grad()
############### Forward ####################
# Fake Generator
out = netG(mosaic_stream,previous_frame)
loss_L2 = lossf_L2(out,ori_stream[:,:,opt.N]) * opt.lambda_L2
loss_VGG = lossf_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
TBGlobalWriter.add_scalars('loss/train', {'L2':loss_L2.item(),'VGG':loss_VGG.item()}, train_iter)
loss = loss_L2+loss_VGG
loss.backward()
# Discriminator
dis_real = netD(torch.cat((mosaic_stream[:,:,opt.N],ori_stream[:,:,opt.N].detach()),dim=1))
dis_fake_D = netD(torch.cat((mosaic_stream[:,:,opt.N],out.detach()),dim=1))
loss_D = lossfun_GAND(dis_fake_D,dis_real) * opt.lambda_GAN
# Generator
dis_fake_G = netD(torch.cat((mosaic_stream[:,:,opt.N],out),dim=1))
loss_L2 = lossfun_L2(out,ori_stream[:,:,opt.N]) * opt.lambda_L2
loss_VGG = lossfun_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
loss_GANG = lossfun_GANG(dis_fake_G) * opt.lambda_GAN
loss_G = loss_L2+loss_VGG+loss_GANG
############### Backward Pass ####################
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
previous_predframe_tmp = out.detach().cpu().numpy()
TBGlobalWriter.add_scalars('loss/train', {'L2':loss_L2.item(),'VGG':loss_VGG.item(),
'loss_D':loss_D.item(),'loss_G':loss_G.item()}, train_iter)
# save network
if train_iter%opt.save_freq == 0 and train_iter != 0:
model_util.save(netG, os.path.join('checkpoints',opt.savename,str(train_iter)+'.pth'), opt.gpu_id)
# psnr
if train_iter%opt.psnr_freq ==0:
psnr = 0
for i in range(len(out)):
psnr += impro.psnr(data.tensor2im(out,batch_index=i), data.tensor2im(ori_stream[:,:,opt.N],batch_index=i))
TBGlobalWriter.add_scalars('psnr', {'train':psnr/len(out)}, train_iter)
model_util.save(netG, os.path.join('checkpoints',opt.savename,str(train_iter)+'_G.pth'), opt.gpu_id)
model_util.save(netD, os.path.join('checkpoints',opt.savename,str(train_iter)+'_D.pth'), opt.gpu_id)
# Image quality evaluation
if train_iter%(opt.showresult_freq//10) == 0:
ImageQualityEvaluation(out,ori_stream[:,:,opt.N],train_iter,TBGlobalWriter,'train')
# Show result
if train_iter % opt.showresult_freq == 0:
show_imgs = []
for i in range(opt.showresult_num):
show_imgs += [data.tensor2im(mosaic_stream[:,:,opt.N],rgb2bgr = False,batch_index=i),
data.tensor2im(out,rgb2bgr = False,batch_index=i),
data.tensor2im(ori_stream[:,:,opt.N],rgb2bgr = False,batch_index=i)]
show_img = impro.splice(show_imgs, (opt.showresult_num,3))
TBGlobalWriter.add_image('train', show_img,train_iter,dataformats='HWC')
ShowImage(mosaic_stream[:,:,opt.N],out,ori_stream[:,:,opt.N],train_iter,opt.showresult_num,TBGlobalWriter,'train')
'''
--------------------------Eval--------------------------
......@@ -144,29 +179,21 @@ for train_iter in range(Videodataloader_train.n_iter):
previous_frame = data.to_tensor(previous_frame, opt.gpu_id)
with torch.no_grad():
out = netG(mosaic_stream,previous_frame)
loss_L2 = lossf_L2(out,ori_stream[:,:,opt.N])
loss_VGG = lossf_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
TBGlobalWriter.add_scalars('loss/eval', {'L2':loss_L2.item(),'VGG':loss_VGG.item()}, train_iter)
loss_L2 = lossfun_L2(out,ori_stream[:,:,opt.N]) * opt.lambda_L2
loss_VGG = lossfun_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
#TBGlobalWriter.add_scalars('loss/eval', {'L2':loss_L2.item(),'VGG':loss_VGG.item()}, train_iter)
previous_predframe_tmp = out.detach().cpu().numpy()
#psnr
if (train_iter)%opt.psnr_freq ==0:
psnr = 0
for i in range(len(out)):
psnr += impro.psnr(data.tensor2im(out,batch_index=i), data.tensor2im(ori_stream[:,:,opt.N],batch_index=i))
TBGlobalWriter.add_scalars('psnr', {'eval':psnr/len(out)}, train_iter)
#show
# Image quality evaluation
if train_iter%(opt.showresult_freq//10) == 0:
psnr,ssmi = ImageQualityEvaluation(out,ori_stream[:,:,opt.N],train_iter,TBGlobalWriter,'eval')
# Show result
if train_iter % opt.showresult_freq == 0:
show_imgs = []
for i in range(opt.showresult_num):
show_imgs += [data.tensor2im(mosaic_stream[:,:,opt.N],rgb2bgr = False,batch_index=i),
data.tensor2im(out,rgb2bgr = False,batch_index=i),
data.tensor2im(ori_stream[:,:,opt.N],rgb2bgr = False,batch_index=i)]
show_img = impro.splice(show_imgs, (opt.showresult_num,3))
TBGlobalWriter.add_image('eval', show_img,train_iter,dataformats='HWC')
ShowImage(mosaic_stream[:,:,opt.N],out,ori_stream[:,:,opt.N],train_iter,opt.showresult_num,TBGlobalWriter,'eval')
t_end = time.time()
print('iter:{0:d} t:{1:.2f} L2:{2:.4f} vgg:{3:.4f} psnr:{4:.2f}'.format(train_iter,t_end-t_start,
loss_L2.item(),loss_VGG.item(),psnr/len(out)) )
print('iter:{0:d} t:{1:.2f} L2:{2:.4f} vgg:{3:.4f} psnr:{4:.2f} ssmi:{5:.3f}'.format(train_iter,t_end-t_start,
loss_L2.item(),loss_VGG.item(),psnr,ssmi) )
t_strat = time.time()
'''
......@@ -179,18 +206,21 @@ for train_iter in range(Videodataloader_train.n_iter):
for video in videos:
frames = os.listdir(os.path.join(opt.dataset_test,video,'image'))
sorted(frames)
mosaic_stream = []
for i in range(opt.T):
_mosaic = impro.imread(os.path.join(opt.dataset_test,video,'image',frames[i*opt.S]),loadsize=opt.finesize,rgb=True)
mosaic_stream.append(_mosaic)
previous = impro.imread(os.path.join(opt.dataset_test,video,'image',frames[opt.N*opt.S-1]),loadsize=opt.finesize,rgb=True)
mosaic_stream = (np.array(mosaic_stream).astype(np.float32)/255.0-0.5)/0.5
mosaic_stream = mosaic_stream.reshape(1,opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
mosaic_stream = data.to_tensor(mosaic_stream, opt.gpu_id)
previous = data.im2tensor(previous,bgr2rgb = False, gpu_id = opt.gpu_id,use_transform = False, is0_1 = False)
with torch.no_grad():
out = netG(mosaic_stream,previous)
for step in range(5):
mosaic_stream = []
for i in range(opt.T):
_mosaic = impro.imread(os.path.join(opt.dataset_test,video,'image',frames[i*opt.S+step]),loadsize=opt.finesize,rgb=True)
mosaic_stream.append(_mosaic)
if step == 0:
previous = impro.imread(os.path.join(opt.dataset_test,video,'image',frames[opt.N*opt.S-1]),loadsize=opt.finesize,rgb=True)
previous = data.im2tensor(previous,bgr2rgb = False, gpu_id = opt.gpu_id,use_transform = False, is0_1 = False)
mosaic_stream = (np.array(mosaic_stream).astype(np.float32)/255.0-0.5)/0.5
mosaic_stream = mosaic_stream.reshape(1,opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
mosaic_stream = data.to_tensor(mosaic_stream, opt.gpu_id)
with torch.no_grad():
out = netG(mosaic_stream,previous)
previous = out
show_imgs+= [data.tensor2im(mosaic_stream[:,:,opt.N],rgb2bgr = False),data.tensor2im(out,rgb2bgr = False)]
show_img = impro.splice(show_imgs, (len(videos),2))
TBGlobalWriter.add_image('test', show_img,train_iter,dataformats='HWC')
TBGlobalWriter.add_image('test', show_img,train_iter,dataformats='HWC')
\ No newline at end of file
......@@ -31,7 +31,12 @@ class VideoLoader(object):
#Init load pool
for i in range(self.opt.S*self.opt.T):
#print(os.path.join(video_dir,'origin_image','%05d' % (i+1)+'.jpg'))
# random
if np.random.random()<0.05:
self.startpos = [random.randint(0,self.mosaic_size),random.randint(0,self.mosaic_size)]
if np.random.random()<0.02:
self.transform_params['rate']['crop'] = [np.random.random(),np.random.random()]
_ori_img = impro.imread(os.path.join(video_dir,'origin_image','%05d' % (i+1)+'.jpg'),loadsize=self.opt.loadsize,rgb=True)
_mask = impro.imread(os.path.join(video_dir,'mask','%05d' % (i+1)+'.png' ),mod='gray',loadsize=self.opt.loadsize)
_mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size,0, self.mod,self.rect_rat,self.feather,self.startpos)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册