From 972dcc875563f94628478a09b60396d4adc1ad76 Mon Sep 17 00:00:00 2001 From: hypox64 Date: Sun, 18 Apr 2021 12:53:43 +0800 Subject: [PATCH] BVDNet --- cores/options.py | 18 +- make_datasets/make_pix2pix_dataset.py | 2 +- models/BVDNet.py | 144 ++++++ models/model_util.py | 10 + train/add/train.py | 2 +- train/clean/train.py | 627 +++++++++++++------------- train/clean/train_old.py | 310 +++++++++++++ util/data.py | 204 ++++++--- util/image_processing.py | 30 +- util/util.py | 5 + 10 files changed, 964 insertions(+), 388 deletions(-) create mode 100644 models/BVDNet.py create mode 100644 models/model_util.py create mode 100644 train/clean/train_old.py diff --git a/cores/options.py b/cores/options.py index 7e8c165..6c972b5 100644 --- a/cores/options.py +++ b/cores/options.py @@ -11,7 +11,7 @@ class Options(): def initialize(self): #base - self.parser.add_argument('--use_gpu', type=int,default=0, help='if -1, use cpu') + self.parser.add_argument('--use_gpu', type=str,default='0', help='if -1, use cpu') self.parser.add_argument('--media_path', type=str, default='./imgs/ruoruo.jpg',help='your videos or images path') self.parser.add_argument('-ss', '--start_time', type=str, default='00:00:00',help='start position of video, default is the beginning of video') self.parser.add_argument('-t', '--last_time', type=str, default='00:00:00',help='duration of the video, default is the entire video') @@ -58,13 +58,15 @@ class Options(): model_name = os.path.basename(self.opt.model_path) self.opt.temp_dir = os.path.join(self.opt.temp_dir, 'DeepMosaics_temp') - - os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.use_gpu) - import torch - if torch.cuda.is_available() and self.opt.use_gpu > -1: - pass - else: - self.opt.use_gpu = -1 + + + if self.opt.use_gpu != '-1': + os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.use_gpu) + import torch + if not torch.cuda.is_available(): + self.opt.use_gpu = '-1' + # else: + # self.opt.use_gpu = '-1' if test_flag: if not os.path.exists(self.opt.media_path): diff --git a/make_datasets/make_pix2pix_dataset.py b/make_datasets/make_pix2pix_dataset.py index c9ccb87..4256f6c 100644 --- a/make_datasets/make_pix2pix_dataset.py +++ b/make_datasets/make_pix2pix_dataset.py @@ -87,7 +87,7 @@ for fold in range(opt.fold): mask = mask_drawn if 'irregular' in opt.mod: mask_irr = impro.imread(irrpaths[random.randint(0,12000-1)],'gray') - mask_irr = data.random_transform_single(mask_irr, (img.shape[0],img.shape[1])) + mask_irr = data.random_transform_single_mask(mask_irr, (img.shape[0],img.shape[1])) mask = mask_irr if 'network' in opt.mod: mask_net = runmodel.get_ROI_position(img,net,opt,keepsize=True)[0] diff --git a/models/BVDNet.py b/models/BVDNet.py new file mode 100644 index 0000000..eb26f04 --- /dev/null +++ b/models/BVDNet.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .pix2pixHD_model import * + + +class Encoder2d(nn.Module): + def __init__(self, input_nc, ngf=64, n_downsampling=3, norm_layer=nn.BatchNorm2d): + super(Encoder2d, self).__init__() + activation = nn.ReLU(True) + + model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation] + ### downsample + for i in range(n_downsampling): + mult = 2**i + model += [nn.ReflectionPad2d(1),nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0), + norm_layer(ngf * mult * 2), activation] + + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + +class Encoder3d(nn.Module): + def 
__init__(self, input_nc, ngf=64, n_downsampling=3, norm_layer=nn.BatchNorm3d): + super(Encoder3d, self).__init__() + activation = nn.ReLU(True) + + model = [nn.Conv3d(input_nc, ngf, kernel_size=3, padding=1), norm_layer(ngf), activation] + ### downsample + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv3d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), activation] + + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + +class BVDNet(nn.Module): + def __init__(self, N, n_downsampling=3, n_blocks=1, input_nc=3, output_nc=3): + super(BVDNet, self).__init__() + + ngf = 64 + padding_type = 'reflect' + norm_layer = nn.BatchNorm2d + self.N = N + + # encoder + self.encoder3d = Encoder3d(input_nc,64,n_downsampling) + self.encoder2d = Encoder2d(input_nc,64,n_downsampling) + + ### resnet blocks + self.blocks = [] + mult = 2**n_downsampling + for i in range(n_blocks): + self.blocks += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=nn.ReLU(True), norm_layer=norm_layer)] + self.blocks = nn.Sequential(*self.blocks) + + ### decoder + self.decoder = [] + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + # self.decoder += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), + # norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] + self.decoder += [ nn.Upsample(scale_factor = 2, mode='nearest'), + nn.ReflectionPad2d(1), + nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + self.decoder += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + self.decoder = nn.Sequential(*self.decoder) + self.limiter = nn.Tanh() + + def forward(self, stream, last): + this_shortcut = stream[:,:,self.N] + stream = self.encoder3d(stream) + stream = stream.reshape(stream.size(0),stream.size(1),stream.size(3),stream.size(4)) + # print(stream.shape) + last = self.encoder2d(last) + x = stream + last + x = self.blocks(x) + x = self.decoder(x) + x = x+this_shortcut + x = self.limiter(x) + #print(x.shape) + + # print(stream.shape,last.shape) + return x + +class VGGLoss(nn.Module): + def __init__(self, gpu_ids): + super(VGGLoss, self).__init__() + + self.vgg = Vgg19() + if gpu_ids != '-1' and len(gpu_ids) == 1: + self.vgg.cuda() + elif gpu_ids != '-1' and len(gpu_ids) > 1: + self.vgg = nn.DataParallel(self.vgg) + self.vgg.cuda() + + self.criterion = nn.L1Loss() + self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss + +from torchvision import models +class Vgg19(torch.nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in 
range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out diff --git a/models/model_util.py b/models/model_util.py new file mode 100644 index 0000000..8f1789f --- /dev/null +++ b/models/model_util.py @@ -0,0 +1,10 @@ +import torch +import torch.nn as nn + +def save(net,path,gpu_id): + if isinstance(net, nn.DataParallel): + torch.save(net.module.cpu().state_dict(),path) + else: + torch.save(net.cpu().state_dict(),path) + if gpu_id != '-1': + net.cuda() \ No newline at end of file diff --git a/train/add/train.py b/train/add/train.py index d64063e..8ccc6ee 100644 --- a/train/add/train.py +++ b/train/add/train.py @@ -68,7 +68,7 @@ def loadimage(imagepaths,maskpaths,opt,test_flag = False): for i in range(len(imagepaths)): img = impro.resize(impro.imread(imagepaths[i]),opt.loadsize) mask = impro.resize(impro.imread(maskpaths[i],mod = 'gray'),opt.loadsize) - img,mask = data.random_transform_image(img, mask, opt.finesize, test_flag) + img,mask = data.random_transform_pair_image(img, mask, opt.finesize, test_flag) images[i] = (img.transpose((2, 0, 1))/255.0) masks[i] = (mask.reshape(1,1,opt.finesize,opt.finesize)/255.0) images = Totensor(images,opt.use_gpu) diff --git a/train/clean/train.py b/train/clean/train.py index 70efb41..49c886f 100644 --- a/train/clean/train.py +++ b/train/clean/train.py @@ -1,310 +1,317 @@ -import os -import sys -sys.path.append("..") -sys.path.append("../..") -from cores import Options -opt = Options() - -import numpy as np -import cv2 -import random -import torch -import torch.nn as nn -import time -from multiprocessing import Process, Queue - -from util import mosaic,util,ffmpeg,filt,data -from util import image_processing as impro -from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model -import matplotlib -matplotlib.use('Agg') -from matplotlib import pyplot as plt -import torch.backends.cudnn as cudnn - -''' ---------------------------Get options-------------------------- -''' -opt.parser.add_argument('--N',type=int,default=25, help='') -opt.parser.add_argument('--lr',type=float,default=0.0002, help='') -opt.parser.add_argument('--beta1',type=float,default=0.5, help='') -opt.parser.add_argument('--gan', action='store_true', help='if specified, use gan') -opt.parser.add_argument('--l2', action='store_true', help='if specified, use L2 loss') -opt.parser.add_argument('--hd', action='store_true', help='if specified, use HD model') -opt.parser.add_argument('--lambda_L1',type=float,default=100, help='') -opt.parser.add_argument('--lambda_gan',type=float,default=1, help='') -opt.parser.add_argument('--finesize',type=int,default=256, help='') -opt.parser.add_argument('--loadsize',type=int,default=286, help='') -opt.parser.add_argument('--batchsize',type=int,default=1, help='') -opt.parser.add_argument('--norm',type=str,default='instance', help='') -opt.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use') -opt.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') -opt.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') 
-opt.parser.add_argument('--image_pool',type=int,default=8, help='number of image load pool') -opt.parser.add_argument('--load_process',type=int,default=4, help='number of process for loading data') - -opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='') -opt.parser.add_argument('--maxiter',type=int,default=10000000, help='') -opt.parser.add_argument('--savefreq',type=int,default=10000, help='') -opt.parser.add_argument('--startiter',type=int,default=0, help='') -opt.parser.add_argument('--continue_train', action='store_true', help='') -opt.parser.add_argument('--savename',type=str,default='face', help='') - - -''' ---------------------------Init-------------------------- -''' -opt = opt.getparse() -dir_checkpoint = os.path.join('checkpoints/',opt.savename) -util.makedirs(dir_checkpoint) -util.writelog(os.path.join(dir_checkpoint,'loss.txt'), - str(time.asctime(time.localtime(time.time())))+'\n'+util.opt2str(opt)) -cudnn.benchmark = True - -N = opt.N -loss_sum = [0.,0.,0.,0.,0.,0] -loss_plot = [[],[],[],[]] -item_plot = [] - -# list video dir -videonames = os.listdir(opt.dataset) -videonames.sort() -lengths = [];tmp = [] -print('Check dataset...') -for video in videonames: - if video != 'opt.txt': - video_images = os.listdir(os.path.join(opt.dataset,video,'origin_image')) - lengths.append(len(video_images)) - tmp.append(video) -videonames = tmp -video_num = len(videonames) - -#--------------------------Init network-------------------------- -print('Init network...') -if opt.hd: - netG = videoHD_model.MosaicNet(3*N+1, 3, norm=opt.norm) -else: - netG = video_model.MosaicNet(3*N+1, 3, norm=opt.norm) -netG.cuda() -loadmodel.show_paramsnumber(netG,'netG') - -if opt.gan: - if opt.hd: - netD = pix2pixHD_model.define_D(6, 64, opt.n_layers_D, norm = opt.norm, use_sigmoid=False, num_D=opt.num_D,getIntermFeat=True) - else: - netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm) - netD.cuda() - netD.train() - -#--------------------------continue train-------------------------- -if opt.continue_train: - if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')): - opt.continue_train = False - print('can not load last_G, training on init weight.') -if opt.continue_train: - netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth'))) - if opt.gan: - netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth'))) - f = open(os.path.join(dir_checkpoint,'iter'),'r') - opt.startiter = int(f.read()) - f.close() - -#--------------------------optimizer & loss-------------------------- -optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999)) -criterion_L1 = nn.L1Loss() -criterion_L2 = nn.MSELoss() -if opt.gan: - optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999)) - if opt.hd: - criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor).cuda() - criterionFeat = pix2pixHD_model.GAN_Feat_loss(opt) - criterionVGG = pix2pixHD_model.VGGLoss([opt.use_gpu]) - else: - criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda() - -''' ---------------------------preload data & data pool-------------------------- -''' -print('Preloading data, please wait...') -def preload(pool): - cnt = 0 - input_imgs = torch.rand(opt.batchsize,N*3+1,opt.finesize,opt.finesize) - ground_trues = torch.rand(opt.batchsize,3,opt.finesize,opt.finesize) - while 1: - try: - for i in range(opt.batchsize): - video_index = random.randint(0,video_num-1) - videoname = videonames[video_index] - img_index = 
random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1) - input_imgs[i],ground_trues[i] = data.load_train_video(videoname,img_index,opt) - cnt += 1 - pool.put([input_imgs,ground_trues]) - except Exception as e: - print("Error:",videoname,e) -pool = Queue(opt.image_pool) -for i in range(opt.load_process): - p = Process(target=preload,args=(pool,)) - p.daemon = True - p.start() - -''' ---------------------------train-------------------------- -''' -util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py')) -util.copyfile('../../models/videoHD_model.py', os.path.join(dir_checkpoint,'model.py')) -netG.train() -time_start=time.time() -print("Begin training...") -for iter in range(opt.startiter+1,opt.maxiter): - - inputdata,target = pool.get() - inputdata,target = inputdata.cuda(),target.cuda() - - if opt.gan: - # compute fake images: G(A) - pred = netG(inputdata) - real_A = inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:] - - # --------------------update D-------------------- - pix2pix_model.set_requires_grad(netD,True) - optimizer_D.zero_grad() - # Fake - fake_AB = torch.cat((real_A, pred), 1) - pred_fake = netD(fake_AB.detach()) - loss_D_fake = criterionGAN(pred_fake, False) - # Real - real_AB = torch.cat((real_A, target), 1) - pred_real = netD(real_AB) - loss_D_real = criterionGAN(pred_real, True) - # combine loss and calculate gradients - loss_D = (loss_D_fake + loss_D_real) * 0.5 - loss_sum[4] += loss_D_fake.item() - loss_sum[5] += loss_D_real.item() - # udpate D's weights - loss_D.backward() - optimizer_D.step() - - # --------------------update G-------------------- - pix2pix_model.set_requires_grad(netD,False) - optimizer_G.zero_grad() - - # First, G(A) should fake the discriminator - fake_AB = torch.cat((real_A, pred), 1) - pred_fake = netD(fake_AB) - loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan - - # combine loss and calculate gradients - if opt.l2: - loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1 - else: - loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1 - - if opt.hd: - real_AB = torch.cat((real_A, target), 1) - pred_real = netD(real_AB) - loss_G_GAN_Feat = criterionFeat(pred_fake,pred_real) - loss_VGG = criterionVGG(pred, target) * opt.lambda_feat - loss_G = loss_G_GAN + loss_G_L1 + loss_G_GAN_Feat + loss_VGG - else: - loss_G = loss_G_GAN + loss_G_L1 - loss_sum[0] += loss_G_L1.item() - loss_sum[1] += loss_G_GAN.item() - loss_sum[2] += loss_G_GAN_Feat.item() - loss_sum[3] += loss_VGG.item() - - # udpate G's weights - loss_G.backward() - optimizer_G.step() - - else: - pred = netG(inputdata) - if opt.l2: - loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1 - else: - loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1 - loss_sum[0] += loss_G_L1.item() - - optimizer_G.zero_grad() - loss_G_L1.backward() - optimizer_G.step() - - # save train result - if (iter+1)%1000 == 0: - try: - data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], - target, pred, os.path.join(dir_checkpoint,'result_train.jpg')) - except Exception as e: - print(e) - - # plot - if (iter+1)%1000 == 0: - time_end = time.time() - #if opt.gan: - savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format( - iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000) - util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True) - if (iter+1)/1000 >= 10: - for i in range(4):loss_plot[i].append(loss_sum[i]/1000) - 
item_plot.append(iter+1) - try: - labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss'] - for i in range(4):plt.plot(item_plot,loss_plot[i],label=labels[i]) - plt.xlabel('iter') - plt.legend(loc=1) - plt.savefig(os.path.join(dir_checkpoint,'loss.jpg')) - plt.close() - except Exception as e: - print("error:",e) - - loss_sum = [0.,0.,0.,0.,0.,0.] - time_start=time.time() - - # save network - if (iter+1)%(opt.savefreq//10) == 0: - torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth')) - if opt.gan: - torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth')) - if opt.use_gpu !=-1 : - netG.cuda() - if opt.gan: - netD.cuda() - f = open(os.path.join(dir_checkpoint,'iter'),'w+') - f.write(str(iter+1)) - f.close() - - if (iter+1)%opt.savefreq == 0: - os.rename(os.path.join(dir_checkpoint,'last_G.pth'),os.path.join(dir_checkpoint,str(iter+1)+'G.pth')) - if opt.gan: - os.rename(os.path.join(dir_checkpoint,'last_D.pth'),os.path.join(dir_checkpoint,str(iter+1)+'D.pth')) - print('network saved.') - - #test - if (iter+1)%opt.savefreq == 0: - if os.path.isdir('./test'): - netG.eval() - - test_names = os.listdir('./test') - test_names.sort() - result = np.zeros((opt.finesize*2,opt.finesize*len(test_names),3), dtype='uint8') - - for cnt,test_name in enumerate(test_names,0): - img_names = os.listdir(os.path.join('./test',test_name,'image')) - img_names.sort() - inputdata = np.zeros((opt.finesize,opt.finesize,3*N+1), dtype='uint8') - for i in range(0,N): - img = impro.imread(os.path.join('./test',test_name,'image',img_names[i])) - img = impro.resize(img,opt.finesize) - inputdata[:,:,i*3:(i+1)*3] = img - - mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray') - mask = impro.resize(mask,opt.finesize) - mask = impro.mask_threshold(mask,15,128) - inputdata[:,:,-1] = mask - result[0:opt.finesize,opt.finesize*cnt:opt.finesize*(cnt+1),:] = inputdata[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3] - inputdata = data.im2tensor(inputdata,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False) - pred = netG(inputdata) - - pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = False) - result[opt.finesize:opt.finesize*2,opt.finesize*cnt:opt.finesize*(cnt+1),:] = pred - - cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.jpg'), result) - netG.train() \ No newline at end of file +import os +import sys +sys.path.append("..") +sys.path.append("../..") +from cores import Options +opt = Options() + +import numpy as np +import cv2 +import random +import torch +import torch.nn as nn +import time +from multiprocessing import Process, Queue + +from util import mosaic,util,ffmpeg,filt,data +from util import image_processing as impro +from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model,BVDNet,model_util +import torch.backends.cudnn as cudnn +from tensorboardX import SummaryWriter + +''' +--------------------------Get options-------------------------- +''' +opt.parser.add_argument('--N',type=int,default=2, help='The input tensor shape is H×W×T×C, T = 2N+1') +opt.parser.add_argument('--S',type=int,default=3, help='Stride of 3 frames') +# opt.parser.add_argument('--T',type=int,default=7, help='T = 2N+1') +opt.parser.add_argument('--M',type=int,default=100, help='How many frames read from each videos') +opt.parser.add_argument('--lr',type=float,default=0.001, help='') +opt.parser.add_argument('--beta1',type=float,default=0.9, help='') +opt.parser.add_argument('--beta2',type=float,default=0.999, help='') 
+opt.parser.add_argument('--finesize',type=int,default=256, help='')
+opt.parser.add_argument('--loadsize',type=int,default=286, help='')
+opt.parser.add_argument('--batchsize',type=int,default=1, help='')
+opt.parser.add_argument('--lambda_VGG',type=float,default=0.1, help='')
+opt.parser.add_argument('--load_thread',type=int,default=4, help='number of processes for loading data')
+
+opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
+opt.parser.add_argument('--n_epoch',type=int,default=200, help='')
+opt.parser.add_argument('--save_freq',type=int,default=100000, help='')
+opt.parser.add_argument('--continue_train', action='store_true', help='')
+opt.parser.add_argument('--savename',type=str,default='face', help='')
+opt.parser.add_argument('--showresult_freq',type=int,default=1000, help='')
+opt.parser.add_argument('--showresult_num',type=int,default=4, help='')
+opt.parser.add_argument('--psnr_freq',type=int,default=100, help='')
+
+class TrainVideoLoader(object):
+    """Sliding-window loader for a single video.
+    1. Init TrainVideoLoader as loader
+    2. Read data from loader.ori_stream / loader.mosaic_stream
+    3. Call loader.next() to advance the window by one frame
+    """
+    def __init__(self, opt, video_dir, test_flag=False):
+        super(TrainVideoLoader, self).__init__()
+        self.opt = opt
+        self.test_flag = test_flag
+        self.video_dir = video_dir
+        self.t = 0
+        self.n_iter = self.opt.M - self.opt.S*(self.opt.T+1)
+        self.transform_params = data.get_transform_params()
+        self.ori_load_pool = []
+        self.mosaic_load_pool = []
+        self.last_pred = None
+        feg_ori = impro.imread(os.path.join(video_dir,'origin_image','00001.jpg'),loadsize=self.opt.loadsize,rgb=True)
+        feg_mask = impro.imread(os.path.join(video_dir,'mask','00001.png'),mod='gray',loadsize=self.opt.loadsize)
+        self.mosaic_size,self.mod,self.rect_rat,self.feather = mosaic.get_random_parameter(feg_ori,feg_mask)
+        self.startpos = [random.randint(0,self.mosaic_size),random.randint(0,self.mosaic_size)]
+
+        # Init load pool
+        for i in range(self.opt.S*self.opt.T):
+            _ori_img = impro.imread(os.path.join(video_dir,'origin_image','%05d' % (i+1)+'.jpg'),loadsize=self.opt.loadsize,rgb=True)
+            _mask = impro.imread(os.path.join(video_dir,'mask','%05d' % (i+1)+'.png'),mod='gray',loadsize=self.opt.loadsize)
+            _mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size,0, self.mod,self.rect_rat,self.feather,self.startpos)
+            # _ori_img = data.random_transform_single_image(_ori_img, opt.finesize,self.transform_params,self.test_flag)
+            # _mosaic_img = data.random_transform_single_image(_mosaic_img, opt.finesize,self.transform_params,self.test_flag)
+            self.ori_load_pool.append(self.normalize(_ori_img))
+            self.mosaic_load_pool.append(self.normalize(_mosaic_img))
+        self.ori_load_pool = np.array(self.ori_load_pool)
+        self.mosaic_load_pool = np.array(self.mosaic_load_pool)
+
+        # Init first stream
+        self.ori_stream = self.ori_load_pool[np.linspace(0, (self.opt.T-1)*self.opt.S,self.opt.T,dtype=np.int64)].copy()
+        self.mosaic_stream = self.mosaic_load_pool[np.linspace(0, (self.opt.T-1)*self.opt.S,self.opt.T,dtype=np.int64)].copy()
+        # stream B,T,H,W,C -> B,C,T,H,W
+        self.ori_stream = self.ori_stream.reshape(1,self.opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
+        self.mosaic_stream = self.mosaic_stream.reshape(1,self.opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
+
+        # Init first previous frame
+        self.last_pred = self.ori_load_pool[self.opt.S*self.opt.N-1].copy()
+        # previous B,C,H,W
+        self.last_pred = self.last_pred.reshape(1,opt.finesize,opt.finesize,3).transpose((0,3,1,2))
+
+    def normalize(self,data):
+        return (data.astype(np.float32)/255.0-0.5)/0.5
+
+    def next(self):
+        if self.t != 0:
+            self.last_pred = None
+            self.ori_load_pool[:self.opt.S*self.opt.T-1] = self.ori_load_pool[1:self.opt.S*self.opt.T]
+            self.mosaic_load_pool[:self.opt.S*self.opt.T-1] = self.mosaic_load_pool[1:self.opt.S*self.opt.T]
+            _ori_img = impro.imread(os.path.join(self.video_dir,'origin_image','%05d' % (self.opt.S*self.opt.T+self.t)+'.jpg'),loadsize=self.opt.loadsize,rgb=True)
+            _mask = impro.imread(os.path.join(self.video_dir,'mask','%05d' % (self.opt.S*self.opt.T+self.t)+'.png'),mod='gray',loadsize=self.opt.loadsize)
+            _mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size,0, self.mod,self.rect_rat,self.feather,self.startpos)
+            # if np.random.random() < 0.01:
+            #     print('1')
+            #     cv2.imwrite(util.randomstr(10)+'.jpg', _ori_img)
+
+            # _ori_img = data.random_transform_single_image(_ori_img, opt.finesize,self.transform_params,self.test_flag)
+            # _mosaic_img = data.random_transform_single_image(_mosaic_img, opt.finesize,self.transform_params,self.test_flag)
+            _ori_img,_mosaic_img = self.normalize(_ori_img),self.normalize(_mosaic_img)
+            self.ori_load_pool[self.opt.S*self.opt.T-1] = _ori_img
+            self.mosaic_load_pool[self.opt.S*self.opt.T-1] = _mosaic_img
+
+            self.ori_stream = self.ori_load_pool[np.linspace(0, (self.opt.T-1)*self.opt.S,self.opt.T,dtype=np.int64)].copy()
+            self.mosaic_stream = self.mosaic_load_pool[np.linspace(0, (self.opt.T-1)*self.opt.S,self.opt.T,dtype=np.int64)].copy()
+
+            # leftover debug dump, disabled to match the commented block above
+            # if np.random.random() < 0.01:
+            #     print('1')
+            #     cv2.imwrite(util.randomstr(10)+'.jpg', self.ori_stream[0])
+
+            # stream B,T,H,W,C -> B,C,T,H,W
+            self.ori_stream = self.ori_stream.reshape(1,self.opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
+            self.mosaic_stream = self.mosaic_stream.reshape(1,self.opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
+
+        self.t += 1
+
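A standalone sketch (not part of the patch) of the frame-index pattern the loader above uses: it keeps a pool of S*T consecutive frames and picks every S-th one, so the network always sees T = 2N+1 frames, and each next() call slides the window by one frame. N and S below are the option defaults.

import numpy as np

N, S = 2, 3              # option defaults from this script
T = 2 * N + 1            # temporal window of 5 frames
pool_len = S * T         # 15 consecutive frames kept in the load pool

idx = np.linspace(0, (T - 1) * S, T, dtype=np.int64)
print(idx)               # [ 0  3  6  9 12] -> every S-th frame of the pool
# next() shifts the pool left by one frame, so the same indices
# select a window that advances one frame per call.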
+class DataLoader(object):
+    """Multiprocess batch loader built on top of TrainVideoLoader."""
+    def __init__(self, opt, videolist, test_flag=False):
+        super(DataLoader, self).__init__()
+        self.videolist = []
+        self.opt = opt
+        self.test_flag = test_flag
+        for i in range(self.opt.n_epoch):
+            self.videolist += videolist
+        random.shuffle(self.videolist)
+        self.each_video_n_iter = self.opt.M - self.opt.S*(self.opt.T+1)
+        self.n_iter = len(self.videolist)//self.opt.load_thread//self.opt.batchsize*self.each_video_n_iter*self.opt.load_thread
+        self.queue = Queue(self.opt.load_thread)
+        self.ori_stream = np.zeros((self.opt.batchsize,3,self.opt.T,self.opt.finesize,self.opt.finesize),dtype=np.float32) # B,C,T,H,W
+        self.mosaic_stream = self.ori_stream.copy()
+        self.last_pred = np.zeros((self.opt.batchsize,3,self.opt.finesize,self.opt.finesize),dtype=np.float32)
+
+    def load(self,videolist):
+        for load_video_iter in range(len(videolist)//self.opt.batchsize):
+            iter_videolist = videolist[load_video_iter*self.opt.batchsize:(load_video_iter+1)*self.opt.batchsize]
+            videoloaders = [TrainVideoLoader(self.opt,os.path.join(self.opt.dataset,iter_videolist[i]),self.test_flag) for i in range(self.opt.batchsize)]
+            for each_video_iter in range(self.each_video_n_iter):
+                for i in range(self.opt.batchsize):
+                    self.ori_stream[i] = videoloaders[i].ori_stream
+                    self.mosaic_stream[i] = videoloaders[i].mosaic_stream
+                    if each_video_iter == 0:
+                        self.last_pred[i] = videoloaders[i].last_pred
+                    videoloaders[i].next()
+                if each_video_iter == 0:
+                    self.queue.put([self.ori_stream,self.mosaic_stream,self.last_pred])
+                else:
+                    self.queue.put([self.ori_stream,self.mosaic_stream,None])
+
+    def load_init(self):
+        ptvn = len(self.videolist)//self.opt.load_thread # videos per loading process
+        for i in range(self.opt.load_thread):
+            p = Process(target=self.load,args=(self.videolist[i*ptvn:(i+1)*ptvn],))
+            p.daemon = True
+            p.start()
+
+    def get_data(self):
+        return self.queue.get()
+
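A minimal consumption sketch for the DataLoader above (videolist_train and opt as defined in this script): load_init() forks opt.load_thread daemon processes that fill a multiprocessing.Queue, and get_data() blocks until a batch is ready.

loader = DataLoader(opt, videolist_train)
loader.load_init()
for it in range(loader.n_iter):
    ori_stream, mosaic_stream, last_pred = loader.get_data()
    # ori_stream / mosaic_stream: (B, 3, T, H, W) float32 in [-1, 1];
    # last_pred is (B, 3, H, W) on the first window of each video and
    # None afterwards, signalling that the previous prediction is reused.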
+'''
+--------------------------Init--------------------------
+'''
+opt = opt.getparse()
+opt.T = 2*opt.N+1
+if opt.showresult_num > opt.batchsize:
+    opt.showresult_num = opt.batchsize
+dir_checkpoint = os.path.join('checkpoints',opt.savename)
+util.makedirs(dir_checkpoint)
+# start tensorboard
+localtime = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
+tensorboard_savedir = os.path.join('checkpoints/tensorboard',localtime+'_'+opt.savename)
+TBGlobalWriter = SummaryWriter(tensorboard_savedir)
+net = BVDNet.BVDNet(opt.N)
+
+if opt.use_gpu != '-1' and len(opt.use_gpu) == 1:
+    torch.backends.cudnn.benchmark = True
+    net.cuda()
+elif opt.use_gpu != '-1' and len(opt.use_gpu) > 1:
+    torch.backends.cudnn.benchmark = True
+    net = nn.DataParallel(net)
+    net.cuda()
+
+optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+lossf_L1 = nn.L1Loss()
+lossf_VGG = BVDNet.VGGLoss([opt.use_gpu])
+
+videolist_tmp = os.listdir(opt.dataset)
+videolist = []
+for video in videolist_tmp:
+    if os.path.isdir(os.path.join(opt.dataset,video)):
+        if len(os.listdir(os.path.join(opt.dataset,video,'mask'))) >= opt.M:
+            videolist.append(video)
+videolist.sort()
+videolist_train = videolist[:int(len(videolist)*0.8)].copy()
+videolist_eval = videolist[int(len(videolist)*0.8):].copy()
+
+dataloader_train = DataLoader(opt, videolist_train)
+dataloader_train.load_init()
+dataloader_eval = DataLoader(opt, videolist_eval)
+dataloader_eval.load_init()
+
+# replaced by the network's own prediction after the first iteration
+previous_predframe_train = 0
+previous_predframe_eval = 0
+for train_iter in range(dataloader_train.n_iter):
+    t_start = time.time()
+    # train
+    ori_stream,mosaic_stream,last_frame = dataloader_train.get_data()
+    ori_stream = data.to_tensor(ori_stream, opt.use_gpu)
+    mosaic_stream = data.to_tensor(mosaic_stream, opt.use_gpu)
+    if last_frame is None:
+        last_frame = data.to_tensor(previous_predframe_train, opt.use_gpu)
+    else:
+        last_frame = data.to_tensor(last_frame, opt.use_gpu)
+    optimizer.zero_grad()
+    out = net(mosaic_stream,last_frame)
+    loss_L1 = lossf_L1(out,ori_stream[:,:,opt.N])
+    loss_VGG = lossf_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
+    TBGlobalWriter.add_scalars('loss/train', {'L1':loss_L1.item(),'VGG':loss_VGG.item()}, train_iter)
+    loss = loss_L1+loss_VGG
+    loss.backward()
+    optimizer.step()
+    previous_predframe_train = out.detach().cpu().numpy()
+
+    # save network
+    if train_iter%opt.save_freq == 0 and train_iter != 0:
+        model_util.save(net, os.path.join('checkpoints',opt.savename,str(train_iter)+'.pth'), opt.use_gpu)
+
+    # psnr
+    if train_iter%opt.psnr_freq == 0:
+        psnr = 0
+        for i in range(len(out)):
+            psnr += impro.psnr(data.tensor2im(out,batch_index=i), data.tensor2im(ori_stream[:,:,opt.N],batch_index=i))
+        TBGlobalWriter.add_scalars('psnr', {'train':psnr/len(out)}, train_iter)
+
+    if train_iter % opt.showresult_freq == 0:
+        show_imgs = []
+        for i in range(opt.showresult_num):
+            show_imgs += [data.tensor2im(mosaic_stream[:,:,opt.N],rgb2bgr = False,batch_index=i),
+                data.tensor2im(out,rgb2bgr = False,batch_index=i),
+                data.tensor2im(ori_stream[:,:,opt.N],rgb2bgr = False,batch_index=i)]
+        show_img = impro.splice(show_imgs, (opt.showresult_num,3))
+        TBGlobalWriter.add_image('train', show_img,train_iter,dataformats='HWC')
+
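The PSNR logged above uses the psnr helper this patch adds to util/image_processing.py; for reference, an equivalent standalone version for uint8 images:

import numpy as np

def psnr_ref(img1, img2):
    # images in [0, 255]; normalize so the peak signal is 1.0
    mse = np.mean((img1 / 255.0 - img2 / 255.0) ** 2)
    return 100 if mse < 1e-10 else 20 * np.log10(1 / np.sqrt(mse))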
+    # eval
+    if (train_iter)%5 == 0:
+        ori_stream,mosaic_stream,last_frame = dataloader_eval.get_data()
+        ori_stream = data.to_tensor(ori_stream, opt.use_gpu)
+        mosaic_stream = data.to_tensor(mosaic_stream, opt.use_gpu)
+        if last_frame is None:
+            last_frame = data.to_tensor(previous_predframe_eval, opt.use_gpu)
+        else:
+            last_frame = data.to_tensor(last_frame, opt.use_gpu)
+        with torch.no_grad():
+            out = net(mosaic_stream,last_frame)
+        loss_L1 = lossf_L1(out,ori_stream[:,:,opt.N])
+        loss_VGG = lossf_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
+        TBGlobalWriter.add_scalars('loss/eval', {'L1':loss_L1.item(),'VGG':loss_VGG.item()}, train_iter)
+        previous_predframe_eval = out.detach().cpu().numpy()
+
+        # psnr
+        if (train_iter)%opt.psnr_freq == 0:
+            psnr = 0
+            for i in range(len(out)):
+                psnr += impro.psnr(data.tensor2im(out,batch_index=i), data.tensor2im(ori_stream[:,:,opt.N],batch_index=i))
+            TBGlobalWriter.add_scalars('psnr', {'eval':psnr/len(out)}, train_iter)
+
+        if train_iter % opt.showresult_freq == 0:
+            show_imgs = []
+            for i in range(opt.showresult_num):
+                show_imgs += [data.tensor2im(mosaic_stream[:,:,opt.N],rgb2bgr = False,batch_index=i),
+                    data.tensor2im(out,rgb2bgr = False,batch_index=i),
+                    data.tensor2im(ori_stream[:,:,opt.N],rgb2bgr = False,batch_index=i)]
+            show_img = impro.splice(show_imgs, (opt.showresult_num,3))
+            TBGlobalWriter.add_image('eval', show_img,train_iter,dataformats='HWC')
+        t_end = time.time()
+        print('iter:{0:d} t:{1:.2f} l1:{2:.4f} vgg:{3:.4f} psnr:{4:.2f}'.format(train_iter,t_end-t_start,
+            loss_L1.item(),loss_VGG.item(),psnr/len(out)) )
+        t_start = time.time()
+
+    # test
+    test_dir = '../../datasets/video_test'
+    if train_iter % opt.showresult_freq == 0 and os.path.isdir(test_dir):
+        show_imgs = []
+        videos = os.listdir(test_dir)
+        videos.sort()
+        for video in videos:
+            frames = os.listdir(os.path.join(test_dir,video,'image'))
+            frames.sort()
+            mosaic_stream = []
+            for i in range(opt.T):
+                _mosaic = impro.imread(os.path.join(test_dir,video,'image',frames[i*opt.S]),loadsize=opt.finesize,rgb=True)
+                mosaic_stream.append(_mosaic)
+            previous = impro.imread(os.path.join(test_dir,video,'image',frames[opt.N*opt.S-1]),loadsize=opt.finesize,rgb=True)
+            mosaic_stream = (np.array(mosaic_stream).astype(np.float32)/255.0-0.5)/0.5
+            mosaic_stream = mosaic_stream.reshape(1,opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
+            mosaic_stream = data.to_tensor(mosaic_stream, opt.use_gpu)
+            previous = data.im2tensor(previous,bgr2rgb = False, use_gpu = opt.use_gpu,use_transform = False, is0_1 = False)
+            with torch.no_grad():
+                out = net(mosaic_stream,previous)
+            show_imgs += [data.tensor2im(mosaic_stream[:,:,opt.N],rgb2bgr = False),data.tensor2im(out,rgb2bgr = False)]
+
+        show_img = impro.splice(show_imgs, (len(videos),2))
+        TBGlobalWriter.add_image('test', show_img,train_iter,dataformats='HWC')
diff --git a/train/clean/train_old.py b/train/clean/train_old.py
new file mode 100644
index 0000000..70efb41
--- /dev/null
+++ b/train/clean/train_old.py
@@ -0,0 +1,310 @@
+import os
+import sys
+sys.path.append("..")
+sys.path.append("../..")
+from cores import Options
+opt = Options()
+
+import numpy as np
+import cv2
+import random
+import torch
+import torch.nn as nn
+import time
+from multiprocessing import Process, Queue
+
+from util import mosaic,util,ffmpeg,filt,data
+from util import image_processing as impro +from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model +import matplotlib +matplotlib.use('Agg') +from matplotlib import pyplot as plt +import torch.backends.cudnn as cudnn + +''' +--------------------------Get options-------------------------- +''' +opt.parser.add_argument('--N',type=int,default=25, help='') +opt.parser.add_argument('--lr',type=float,default=0.0002, help='') +opt.parser.add_argument('--beta1',type=float,default=0.5, help='') +opt.parser.add_argument('--gan', action='store_true', help='if specified, use gan') +opt.parser.add_argument('--l2', action='store_true', help='if specified, use L2 loss') +opt.parser.add_argument('--hd', action='store_true', help='if specified, use HD model') +opt.parser.add_argument('--lambda_L1',type=float,default=100, help='') +opt.parser.add_argument('--lambda_gan',type=float,default=1, help='') +opt.parser.add_argument('--finesize',type=int,default=256, help='') +opt.parser.add_argument('--loadsize',type=int,default=286, help='') +opt.parser.add_argument('--batchsize',type=int,default=1, help='') +opt.parser.add_argument('--norm',type=str,default='instance', help='') +opt.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use') +opt.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') +opt.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') +opt.parser.add_argument('--image_pool',type=int,default=8, help='number of image load pool') +opt.parser.add_argument('--load_process',type=int,default=4, help='number of process for loading data') + +opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='') +opt.parser.add_argument('--maxiter',type=int,default=10000000, help='') +opt.parser.add_argument('--savefreq',type=int,default=10000, help='') +opt.parser.add_argument('--startiter',type=int,default=0, help='') +opt.parser.add_argument('--continue_train', action='store_true', help='') +opt.parser.add_argument('--savename',type=str,default='face', help='') + + +''' +--------------------------Init-------------------------- +''' +opt = opt.getparse() +dir_checkpoint = os.path.join('checkpoints/',opt.savename) +util.makedirs(dir_checkpoint) +util.writelog(os.path.join(dir_checkpoint,'loss.txt'), + str(time.asctime(time.localtime(time.time())))+'\n'+util.opt2str(opt)) +cudnn.benchmark = True + +N = opt.N +loss_sum = [0.,0.,0.,0.,0.,0] +loss_plot = [[],[],[],[]] +item_plot = [] + +# list video dir +videonames = os.listdir(opt.dataset) +videonames.sort() +lengths = [];tmp = [] +print('Check dataset...') +for video in videonames: + if video != 'opt.txt': + video_images = os.listdir(os.path.join(opt.dataset,video,'origin_image')) + lengths.append(len(video_images)) + tmp.append(video) +videonames = tmp +video_num = len(videonames) + +#--------------------------Init network-------------------------- +print('Init network...') +if opt.hd: + netG = videoHD_model.MosaicNet(3*N+1, 3, norm=opt.norm) +else: + netG = video_model.MosaicNet(3*N+1, 3, norm=opt.norm) +netG.cuda() +loadmodel.show_paramsnumber(netG,'netG') + +if opt.gan: + if opt.hd: + netD = pix2pixHD_model.define_D(6, 64, opt.n_layers_D, norm = opt.norm, use_sigmoid=False, num_D=opt.num_D,getIntermFeat=True) + else: + netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm) + netD.cuda() + netD.train() + +#--------------------------continue 
train-------------------------- +if opt.continue_train: + if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')): + opt.continue_train = False + print('can not load last_G, training on init weight.') +if opt.continue_train: + netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth'))) + if opt.gan: + netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth'))) + f = open(os.path.join(dir_checkpoint,'iter'),'r') + opt.startiter = int(f.read()) + f.close() + +#--------------------------optimizer & loss-------------------------- +optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999)) +criterion_L1 = nn.L1Loss() +criterion_L2 = nn.MSELoss() +if opt.gan: + optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999)) + if opt.hd: + criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor).cuda() + criterionFeat = pix2pixHD_model.GAN_Feat_loss(opt) + criterionVGG = pix2pixHD_model.VGGLoss([opt.use_gpu]) + else: + criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda() + +''' +--------------------------preload data & data pool-------------------------- +''' +print('Preloading data, please wait...') +def preload(pool): + cnt = 0 + input_imgs = torch.rand(opt.batchsize,N*3+1,opt.finesize,opt.finesize) + ground_trues = torch.rand(opt.batchsize,3,opt.finesize,opt.finesize) + while 1: + try: + for i in range(opt.batchsize): + video_index = random.randint(0,video_num-1) + videoname = videonames[video_index] + img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1) + input_imgs[i],ground_trues[i] = data.load_train_video(videoname,img_index,opt) + cnt += 1 + pool.put([input_imgs,ground_trues]) + except Exception as e: + print("Error:",videoname,e) +pool = Queue(opt.image_pool) +for i in range(opt.load_process): + p = Process(target=preload,args=(pool,)) + p.daemon = True + p.start() + +''' +--------------------------train-------------------------- +''' +util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py')) +util.copyfile('../../models/videoHD_model.py', os.path.join(dir_checkpoint,'model.py')) +netG.train() +time_start=time.time() +print("Begin training...") +for iter in range(opt.startiter+1,opt.maxiter): + + inputdata,target = pool.get() + inputdata,target = inputdata.cuda(),target.cuda() + + if opt.gan: + # compute fake images: G(A) + pred = netG(inputdata) + real_A = inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:] + + # --------------------update D-------------------- + pix2pix_model.set_requires_grad(netD,True) + optimizer_D.zero_grad() + # Fake + fake_AB = torch.cat((real_A, pred), 1) + pred_fake = netD(fake_AB.detach()) + loss_D_fake = criterionGAN(pred_fake, False) + # Real + real_AB = torch.cat((real_A, target), 1) + pred_real = netD(real_AB) + loss_D_real = criterionGAN(pred_real, True) + # combine loss and calculate gradients + loss_D = (loss_D_fake + loss_D_real) * 0.5 + loss_sum[4] += loss_D_fake.item() + loss_sum[5] += loss_D_real.item() + # udpate D's weights + loss_D.backward() + optimizer_D.step() + + # --------------------update G-------------------- + pix2pix_model.set_requires_grad(netD,False) + optimizer_G.zero_grad() + + # First, G(A) should fake the discriminator + fake_AB = torch.cat((real_A, pred), 1) + pred_fake = netD(fake_AB) + loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan + + # combine loss and calculate gradients + if opt.l2: + loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1 + else: 
+ loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1 + + if opt.hd: + real_AB = torch.cat((real_A, target), 1) + pred_real = netD(real_AB) + loss_G_GAN_Feat = criterionFeat(pred_fake,pred_real) + loss_VGG = criterionVGG(pred, target) * opt.lambda_feat + loss_G = loss_G_GAN + loss_G_L1 + loss_G_GAN_Feat + loss_VGG + else: + loss_G = loss_G_GAN + loss_G_L1 + loss_sum[0] += loss_G_L1.item() + loss_sum[1] += loss_G_GAN.item() + loss_sum[2] += loss_G_GAN_Feat.item() + loss_sum[3] += loss_VGG.item() + + # udpate G's weights + loss_G.backward() + optimizer_G.step() + + else: + pred = netG(inputdata) + if opt.l2: + loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1 + else: + loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1 + loss_sum[0] += loss_G_L1.item() + + optimizer_G.zero_grad() + loss_G_L1.backward() + optimizer_G.step() + + # save train result + if (iter+1)%1000 == 0: + try: + data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], + target, pred, os.path.join(dir_checkpoint,'result_train.jpg')) + except Exception as e: + print(e) + + # plot + if (iter+1)%1000 == 0: + time_end = time.time() + #if opt.gan: + savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format( + iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000) + util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True) + if (iter+1)/1000 >= 10: + for i in range(4):loss_plot[i].append(loss_sum[i]/1000) + item_plot.append(iter+1) + try: + labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss'] + for i in range(4):plt.plot(item_plot,loss_plot[i],label=labels[i]) + plt.xlabel('iter') + plt.legend(loc=1) + plt.savefig(os.path.join(dir_checkpoint,'loss.jpg')) + plt.close() + except Exception as e: + print("error:",e) + + loss_sum = [0.,0.,0.,0.,0.,0.] 
+ time_start=time.time() + + # save network + if (iter+1)%(opt.savefreq//10) == 0: + torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth')) + if opt.gan: + torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth')) + if opt.use_gpu !=-1 : + netG.cuda() + if opt.gan: + netD.cuda() + f = open(os.path.join(dir_checkpoint,'iter'),'w+') + f.write(str(iter+1)) + f.close() + + if (iter+1)%opt.savefreq == 0: + os.rename(os.path.join(dir_checkpoint,'last_G.pth'),os.path.join(dir_checkpoint,str(iter+1)+'G.pth')) + if opt.gan: + os.rename(os.path.join(dir_checkpoint,'last_D.pth'),os.path.join(dir_checkpoint,str(iter+1)+'D.pth')) + print('network saved.') + + #test + if (iter+1)%opt.savefreq == 0: + if os.path.isdir('./test'): + netG.eval() + + test_names = os.listdir('./test') + test_names.sort() + result = np.zeros((opt.finesize*2,opt.finesize*len(test_names),3), dtype='uint8') + + for cnt,test_name in enumerate(test_names,0): + img_names = os.listdir(os.path.join('./test',test_name,'image')) + img_names.sort() + inputdata = np.zeros((opt.finesize,opt.finesize,3*N+1), dtype='uint8') + for i in range(0,N): + img = impro.imread(os.path.join('./test',test_name,'image',img_names[i])) + img = impro.resize(img,opt.finesize) + inputdata[:,:,i*3:(i+1)*3] = img + + mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray') + mask = impro.resize(mask,opt.finesize) + mask = impro.mask_threshold(mask,15,128) + inputdata[:,:,-1] = mask + result[0:opt.finesize,opt.finesize*cnt:opt.finesize*(cnt+1),:] = inputdata[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3] + inputdata = data.im2tensor(inputdata,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False) + pred = netG(inputdata) + + pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = False) + result[opt.finesize:opt.finesize*2,opt.finesize*cnt:opt.finesize*(cnt+1),:] = pred + + cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.jpg'), result) + netG.train() \ No newline at end of file diff --git a/util/data.py b/util/data.py index c76dfa3..8884047 100755 --- a/util/data.py +++ b/util/data.py @@ -10,11 +10,18 @@ transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)) ] -) +) -def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 = False): +def to_tensor(data,gpu_id): + data = torch.from_numpy(data) + if gpu_id != '-1': + data = data.cuda() + return data + + +def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 = False, batch_index=0): image_tensor =image_tensor.data - image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = image_tensor[batch_index].cpu().float().numpy() if not is0_1: image_numpy = (image_numpy + 1)/2.0 @@ -58,7 +65,7 @@ def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape = image_tensor = torch.from_numpy(image_numpy).float() if reshape: image_tensor = image_tensor.reshape(1,ch,h,w) - if use_gpu != -1: + if use_gpu != '-1': image_tensor = image_tensor.cuda() return image_tensor @@ -68,53 +75,53 @@ def shuffledata(data,target): np.random.set_state(state) np.random.shuffle(target) -def random_transform_video(src,target,finesize,N): - #random blur - if random.random()<0.2: - h,w = src.shape[:2] - src = src[:8*(h//8),:8*(w//8)] - Q_ran = random.randint(1,15) - src[:,:,:3*N] = impro.dctblur(src[:,:,:3*N],Q_ran) - target = impro.dctblur(target,Q_ran) +# def random_transform_video(src,target,finesize,N): +# #random blur +# if 
random.random()<0.2: +# h,w = src.shape[:2] +# src = src[:8*(h//8),:8*(w//8)] +# Q_ran = random.randint(1,15) +# src[:,:,:3*N] = impro.dctblur(src[:,:,:3*N],Q_ran) +# target = impro.dctblur(target,Q_ran) - #random crop - h,w = target.shape[:2] - h_move = int((h-finesize)*random.random()) - w_move = int((w-finesize)*random.random()) - target = target[h_move:h_move+finesize,w_move:w_move+finesize,:] - src = src[h_move:h_move+finesize,w_move:w_move+finesize,:] +# #random crop +# h,w = target.shape[:2] +# h_move = int((h-finesize)*random.random()) +# w_move = int((w-finesize)*random.random()) +# target = target[h_move:h_move+finesize,w_move:w_move+finesize,:] +# src = src[h_move:h_move+finesize,w_move:w_move+finesize,:] - #random flip - if random.random()<0.5: - src = src[:,::-1,:] - target = target[:,::-1,:] +# #random flip +# if random.random()<0.5: +# src = src[:,::-1,:] +# target = target[:,::-1,:] - #random color - alpha = random.uniform(-0.1,0.1) - beta = random.uniform(-0.1,0.1) - b = random.uniform(-0.05,0.05) - g = random.uniform(-0.05,0.05) - r = random.uniform(-0.05,0.05) - for i in range(N): - src[:,:,i*3:(i+1)*3] = impro.color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r) - target = impro.color_adjust(target,alpha,beta,b,g,r) - - #random resize blur - if random.random()<0.5: - interpolations = [cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_LANCZOS4] - size_ran = random.uniform(0.7,1.5) - interpolation_up = interpolations[random.randint(0,2)] - interpolation_down =interpolations[random.randint(0,2)] +# #random color +# alpha = random.uniform(-0.1,0.1) +# beta = random.uniform(-0.1,0.1) +# b = random.uniform(-0.05,0.05) +# g = random.uniform(-0.05,0.05) +# r = random.uniform(-0.05,0.05) +# for i in range(N): +# src[:,:,i*3:(i+1)*3] = impro.color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r) +# target = impro.color_adjust(target,alpha,beta,b,g,r) + +# #random resize blur +# if random.random()<0.5: +# interpolations = [cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_LANCZOS4] +# size_ran = random.uniform(0.7,1.5) +# interpolation_up = interpolations[random.randint(0,2)] +# interpolation_down =interpolations[random.randint(0,2)] - tmp = cv2.resize(src[:,:,:3*N], (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up) - src[:,:,:3*N] = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down) +# tmp = cv2.resize(src[:,:,:3*N], (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up) +# src[:,:,:3*N] = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down) - tmp = cv2.resize(target, (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up) - target = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down) +# tmp = cv2.resize(target, (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up) +# target = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down) - return src,target +# return src,target -def random_transform_single(img,out_shape): +def random_transform_single_mask(img,out_shape): out_h,out_w = out_shape img = cv2.resize(img,(int(out_w*random.uniform(1.1, 1.5)),int(out_h*random.uniform(1.1, 1.5)))) h,w = img.shape[:2] @@ -130,7 +137,72 @@ def random_transform_single(img,out_shape): img = cv2.resize(img,(out_w,out_h)) return img -def random_transform_image(img,mask,finesize,test_flag = False): +def get_transform_params(): + scale_flag = np.random.random()<0.2 + crop_flag = True + rotat_flag = np.random.random()<0.2 + color_flag = True + 
flip_flag = np.random.random()<0.2
+    blur_flag = np.random.random()<0.1
+    flag_dict = {'scale':scale_flag,'crop':crop_flag,'rotat':rotat_flag,'color':color_flag,
+        'flip':flip_flag,'blur':blur_flag}
+
+    scale_rate = np.random.uniform(0.9,1.1)
+    crop_rate = [np.random.random(),np.random.random()]
+    rotat_rate = np.random.random()
+    color_rate = [np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05),
+        np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05)]
+    flip_rate = np.random.random()
+    blur_rate = np.random.randint(1,15)
+    rate_dict = {'scale':scale_rate,'crop':crop_rate,'rotat':rotat_rate,'color':color_rate,
+        'flip':flip_rate,'blur':blur_rate}
+
+    return {'flag':flag_dict,'rate':rate_dict}
+
+def random_transform_single_image(img,finesize,params=None,test_flag = False):
+    if params is None:
+        params = get_transform_params()
+    if test_flag:
+        params['flag']['scale'] = False
+    if params['flag']['scale']:
+        h,w = img.shape[:2]
+        loadsize = min((h,w))
+        a = (float(h)/float(w))*params['rate']['scale']
+        if h<w:
diff --git a/util/image_processing.py b/util/image_processing.py
--- a/util/image_processing.py
+++ b/util/image_processing.py
@@ -43,6 +43,9 @@
     if loadsize != 0:
         img = resize(img, loadsize, interpolation=cv2.INTER_CUBIC)
 
+    if rgb and img.ndim==3:
+        img = img[:,:,::-1]
+
     return img
 
 def imwrite(file_path,img):
@@ -252,4 +255,27 @@ def replace_mosaic(img_origin,img_fake,mask,x,y,size,no_feather):
     img_result = img_origin.copy()
     img_result = (img_origin*(1-mask)+img_tmp*mask).astype('uint8')
-    return img_result
\ No newline at end of file
+    return img_result
+
+def psnr(img1,img2):
+    mse = np.mean((img1/255.0-img2/255.0)**2)
+    if mse < 1e-10:
+        return 100
+    psnr_v = 20*np.log10(1/np.sqrt(mse))
+    return psnr_v
+
+def splice(imgs,splice_shape):
+    '''Stitch multiple images into one; all images must have the same size.
+    imgs : [img1,img2,img3,img4]
+    splice_shape: (2,2)
+    '''
+    h,w,ch = imgs[0].shape
+    output = np.zeros((h*splice_shape[0],w*splice_shape[1],ch),np.uint8)
+    cnt = 0
+    for i in range(splice_shape[0]):
+        for j in range(splice_shape[1]):
+            if cnt < len(imgs):
+                output[h*i:h*(i+1),w*j:w*(j+1)] = imgs[cnt]
+                cnt += 1
+    return output
+
diff --git a/util/util.py b/util/util.py
index 571974e..4952df8 100755
--- a/util/util.py
+++ b/util/util.py
@@ -1,4 +1,6 @@
 import os
+import random
+import string
 import shutil
 
 def Traversal(filedir):
@@ -10,6 +12,9 @@ def Traversal(filedir):
             Traversal(dir)
     return file_list
 
+def randomstr(num):
+    return ''.join(random.sample(string.ascii_letters + string.digits, num))
+
 def is_img(path):
     ext = os.path.splitext(path)[1]
     ext = ext.lower()
-- 
GitLab
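For orientation, a standalone shape check for the new BVDNet (a sketch, not part of the patch, assuming the repo's models package is importable; small sizes chosen so it runs quickly on CPU): the stream input is B,C,T,H,W with T = 2N+1, the previous frame is B,C,H,W, and the output matches the previous-frame shape.

import torch
from models import BVDNet

N = 2
net = BVDNet.BVDNet(N)
net.eval()                                    # use BatchNorm running stats for a batch of 1
stream = torch.rand(1, 3, 2 * N + 1, 64, 64)  # B,C,T,H,W mosaic stream
last = torch.rand(1, 3, 64, 64)               # B,C,H,W previous frame
with torch.no_grad():
    out = net(stream, last)
print(out.shape)                              # torch.Size([1, 3, 64, 64])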