提交 972dcc87 编写于 作者: H hypox64

BVDNet

上级 538ccbcd
......@@ -11,7 +11,7 @@ class Options():
def initialize(self):
#base
self.parser.add_argument('--use_gpu', type=int,default=0, help='if -1, use cpu')
self.parser.add_argument('--use_gpu', type=str,default='0', help='if -1, use cpu')
self.parser.add_argument('--media_path', type=str, default='./imgs/ruoruo.jpg',help='your videos or images path')
self.parser.add_argument('-ss', '--start_time', type=str, default='00:00:00',help='start position of video, default is the beginning of video')
self.parser.add_argument('-t', '--last_time', type=str, default='00:00:00',help='duration of the video, default is the entire video')
......@@ -58,13 +58,15 @@ class Options():
model_name = os.path.basename(self.opt.model_path)
self.opt.temp_dir = os.path.join(self.opt.temp_dir, 'DeepMosaics_temp')
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.use_gpu)
import torch
if torch.cuda.is_available() and self.opt.use_gpu > -1:
pass
else:
self.opt.use_gpu = -1
if self.opt.use_gpu != '-1':
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.use_gpu)
import torch
if not torch.cuda.is_available():
self.opt.use_gpu = '-1'
# else:
# self.opt.use_gpu = '-1'
if test_flag:
if not os.path.exists(self.opt.media_path):
......
......@@ -87,7 +87,7 @@ for fold in range(opt.fold):
mask = mask_drawn
if 'irregular' in opt.mod:
mask_irr = impro.imread(irrpaths[random.randint(0,12000-1)],'gray')
mask_irr = data.random_transform_single(mask_irr, (img.shape[0],img.shape[1]))
mask_irr = data.random_transform_single_mask(mask_irr, (img.shape[0],img.shape[1]))
mask = mask_irr
if 'network' in opt.mod:
mask_net = runmodel.get_ROI_position(img,net,opt,keepsize=True)[0]
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pix2pixHD_model import *
class Encoder2d(nn.Module):
def __init__(self, input_nc, ngf=64, n_downsampling=3, norm_layer=nn.BatchNorm2d):
super(Encoder2d, self).__init__()
activation = nn.ReLU(True)
model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.ReflectionPad2d(1),nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0),
norm_layer(ngf * mult * 2), activation]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
class Encoder3d(nn.Module):
def __init__(self, input_nc, ngf=64, n_downsampling=3, norm_layer=nn.BatchNorm3d):
super(Encoder3d, self).__init__()
activation = nn.ReLU(True)
model = [nn.Conv3d(input_nc, ngf, kernel_size=3, padding=1), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv3d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2), activation]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
class BVDNet(nn.Module):
def __init__(self, N, n_downsampling=3, n_blocks=1, input_nc=3, output_nc=3):
super(BVDNet, self).__init__()
ngf = 64
padding_type = 'reflect'
norm_layer = nn.BatchNorm2d
self.N = N
# encoder
self.encoder3d = Encoder3d(input_nc,64,n_downsampling)
self.encoder2d = Encoder2d(input_nc,64,n_downsampling)
### resnet blocks
self.blocks = []
mult = 2**n_downsampling
for i in range(n_blocks):
self.blocks += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=nn.ReLU(True), norm_layer=norm_layer)]
self.blocks = nn.Sequential(*self.blocks)
### decoder
self.decoder = []
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
# self.decoder += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
# norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
self.decoder += [ nn.Upsample(scale_factor = 2, mode='nearest'),
nn.ReflectionPad2d(1),
nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
self.decoder += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
self.decoder = nn.Sequential(*self.decoder)
self.limiter = nn.Tanh()
def forward(self, stream, last):
this_shortcut = stream[:,:,self.N]
stream = self.encoder3d(stream)
stream = stream.reshape(stream.size(0),stream.size(1),stream.size(3),stream.size(4))
# print(stream.shape)
last = self.encoder2d(last)
x = stream + last
x = self.blocks(x)
x = self.decoder(x)
x = x+this_shortcut
x = self.limiter(x)
#print(x.shape)
# print(stream.shape,last.shape)
return x
class VGGLoss(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss, self).__init__()
self.vgg = Vgg19()
if gpu_ids != '-1' and len(gpu_ids) == 1:
self.vgg.cuda()
elif gpu_ids != '-1' and len(gpu_ids) > 1:
self.vgg = nn.DataParallel(self.vgg)
self.vgg.cuda()
self.criterion = nn.L1Loss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
from torchvision import models
class Vgg19(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
import torch
import torch.nn as nn
def save(net,path,gpu_id):
if isinstance(net, nn.DataParallel):
torch.save(net.module.cpu().state_dict(),path)
else:
torch.save(net.cpu().state_dict(),path)
if gpu_id != '-1':
net.cuda()
\ No newline at end of file
......@@ -68,7 +68,7 @@ def loadimage(imagepaths,maskpaths,opt,test_flag = False):
for i in range(len(imagepaths)):
img = impro.resize(impro.imread(imagepaths[i]),opt.loadsize)
mask = impro.resize(impro.imread(maskpaths[i],mod = 'gray'),opt.loadsize)
img,mask = data.random_transform_image(img, mask, opt.finesize, test_flag)
img,mask = data.random_transform_pair_image(img, mask, opt.finesize, test_flag)
images[i] = (img.transpose((2, 0, 1))/255.0)
masks[i] = (mask.reshape(1,1,opt.finesize,opt.finesize)/255.0)
images = Totensor(images,opt.use_gpu)
......
此差异已折叠。
import os
import sys
sys.path.append("..")
sys.path.append("../..")
from cores import Options
opt = Options()
import numpy as np
import cv2
import random
import torch
import torch.nn as nn
import time
from multiprocessing import Process, Queue
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn
'''
--------------------------Get options--------------------------
'''
opt.parser.add_argument('--N',type=int,default=25, help='')
opt.parser.add_argument('--lr',type=float,default=0.0002, help='')
opt.parser.add_argument('--beta1',type=float,default=0.5, help='')
opt.parser.add_argument('--gan', action='store_true', help='if specified, use gan')
opt.parser.add_argument('--l2', action='store_true', help='if specified, use L2 loss')
opt.parser.add_argument('--hd', action='store_true', help='if specified, use HD model')
opt.parser.add_argument('--lambda_L1',type=float,default=100, help='')
opt.parser.add_argument('--lambda_gan',type=float,default=1, help='')
opt.parser.add_argument('--finesize',type=int,default=256, help='')
opt.parser.add_argument('--loadsize',type=int,default=286, help='')
opt.parser.add_argument('--batchsize',type=int,default=1, help='')
opt.parser.add_argument('--norm',type=str,default='instance', help='')
opt.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
opt.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
opt.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
opt.parser.add_argument('--image_pool',type=int,default=8, help='number of image load pool')
opt.parser.add_argument('--load_process',type=int,default=4, help='number of process for loading data')
opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
opt.parser.add_argument('--maxiter',type=int,default=10000000, help='')
opt.parser.add_argument('--savefreq',type=int,default=10000, help='')
opt.parser.add_argument('--startiter',type=int,default=0, help='')
opt.parser.add_argument('--continue_train', action='store_true', help='')
opt.parser.add_argument('--savename',type=str,default='face', help='')
'''
--------------------------Init--------------------------
'''
opt = opt.getparse()
dir_checkpoint = os.path.join('checkpoints/',opt.savename)
util.makedirs(dir_checkpoint)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
str(time.asctime(time.localtime(time.time())))+'\n'+util.opt2str(opt))
cudnn.benchmark = True
N = opt.N
loss_sum = [0.,0.,0.,0.,0.,0]
loss_plot = [[],[],[],[]]
item_plot = []
# list video dir
videonames = os.listdir(opt.dataset)
videonames.sort()
lengths = [];tmp = []
print('Check dataset...')
for video in videonames:
if video != 'opt.txt':
video_images = os.listdir(os.path.join(opt.dataset,video,'origin_image'))
lengths.append(len(video_images))
tmp.append(video)
videonames = tmp
video_num = len(videonames)
#--------------------------Init network--------------------------
print('Init network...')
if opt.hd:
netG = videoHD_model.MosaicNet(3*N+1, 3, norm=opt.norm)
else:
netG = video_model.MosaicNet(3*N+1, 3, norm=opt.norm)
netG.cuda()
loadmodel.show_paramsnumber(netG,'netG')
if opt.gan:
if opt.hd:
netD = pix2pixHD_model.define_D(6, 64, opt.n_layers_D, norm = opt.norm, use_sigmoid=False, num_D=opt.num_D,getIntermFeat=True)
else:
netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm)
netD.cuda()
netD.train()
#--------------------------continue train--------------------------
if opt.continue_train:
if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
opt.continue_train = False
print('can not load last_G, training on init weight.')
if opt.continue_train:
netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
if opt.gan:
netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
f = open(os.path.join(dir_checkpoint,'iter'),'r')
opt.startiter = int(f.read())
f.close()
#--------------------------optimizer & loss--------------------------
optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
if opt.gan:
optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
if opt.hd:
criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor).cuda()
criterionFeat = pix2pixHD_model.GAN_Feat_loss(opt)
criterionVGG = pix2pixHD_model.VGGLoss([opt.use_gpu])
else:
criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()
'''
--------------------------preload data & data pool--------------------------
'''
print('Preloading data, please wait...')
def preload(pool):
cnt = 0
input_imgs = torch.rand(opt.batchsize,N*3+1,opt.finesize,opt.finesize)
ground_trues = torch.rand(opt.batchsize,3,opt.finesize,opt.finesize)
while 1:
try:
for i in range(opt.batchsize):
video_index = random.randint(0,video_num-1)
videoname = videonames[video_index]
img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
input_imgs[i],ground_trues[i] = data.load_train_video(videoname,img_index,opt)
cnt += 1
pool.put([input_imgs,ground_trues])
except Exception as e:
print("Error:",videoname,e)
pool = Queue(opt.image_pool)
for i in range(opt.load_process):
p = Process(target=preload,args=(pool,))
p.daemon = True
p.start()
'''
--------------------------train--------------------------
'''
util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
util.copyfile('../../models/videoHD_model.py', os.path.join(dir_checkpoint,'model.py'))
netG.train()
time_start=time.time()
print("Begin training...")
for iter in range(opt.startiter+1,opt.maxiter):
inputdata,target = pool.get()
inputdata,target = inputdata.cuda(),target.cuda()
if opt.gan:
# compute fake images: G(A)
pred = netG(inputdata)
real_A = inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]
# --------------------update D--------------------
pix2pix_model.set_requires_grad(netD,True)
optimizer_D.zero_grad()
# Fake
fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB.detach())
loss_D_fake = criterionGAN(pred_fake, False)
# Real
real_AB = torch.cat((real_A, target), 1)
pred_real = netD(real_AB)
loss_D_real = criterionGAN(pred_real, True)
# combine loss and calculate gradients
loss_D = (loss_D_fake + loss_D_real) * 0.5
loss_sum[4] += loss_D_fake.item()
loss_sum[5] += loss_D_real.item()
# udpate D's weights
loss_D.backward()
optimizer_D.step()
# --------------------update G--------------------
pix2pix_model.set_requires_grad(netD,False)
optimizer_G.zero_grad()
# First, G(A) should fake the discriminator
fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB)
loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan
# combine loss and calculate gradients
if opt.l2:
loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
else:
loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1
if opt.hd:
real_AB = torch.cat((real_A, target), 1)
pred_real = netD(real_AB)
loss_G_GAN_Feat = criterionFeat(pred_fake,pred_real)
loss_VGG = criterionVGG(pred, target) * opt.lambda_feat
loss_G = loss_G_GAN + loss_G_L1 + loss_G_GAN_Feat + loss_VGG
else:
loss_G = loss_G_GAN + loss_G_L1
loss_sum[0] += loss_G_L1.item()
loss_sum[1] += loss_G_GAN.item()
loss_sum[2] += loss_G_GAN_Feat.item()
loss_sum[3] += loss_VGG.item()
# udpate G's weights
loss_G.backward()
optimizer_G.step()
else:
pred = netG(inputdata)
if opt.l2:
loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
else:
loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1
loss_sum[0] += loss_G_L1.item()
optimizer_G.zero_grad()
loss_G_L1.backward()
optimizer_G.step()
# save train result
if (iter+1)%1000 == 0:
try:
data.showresult(inputdata[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:],
target, pred, os.path.join(dir_checkpoint,'result_train.jpg'))
except Exception as e:
print(e)
# plot
if (iter+1)%1000 == 0:
time_end = time.time()
#if opt.gan:
savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format(
iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
if (iter+1)/1000 >= 10:
for i in range(4):loss_plot[i].append(loss_sum[i]/1000)
item_plot.append(iter+1)
try:
labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss']
for i in range(4):plt.plot(item_plot,loss_plot[i],label=labels[i])
plt.xlabel('iter')
plt.legend(loc=1)
plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
plt.close()
except Exception as e:
print("error:",e)
loss_sum = [0.,0.,0.,0.,0.,0.]
time_start=time.time()
# save network
if (iter+1)%(opt.savefreq//10) == 0:
torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth'))
if opt.gan:
torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth'))
if opt.use_gpu !=-1 :
netG.cuda()
if opt.gan:
netD.cuda()
f = open(os.path.join(dir_checkpoint,'iter'),'w+')
f.write(str(iter+1))
f.close()
if (iter+1)%opt.savefreq == 0:
os.rename(os.path.join(dir_checkpoint,'last_G.pth'),os.path.join(dir_checkpoint,str(iter+1)+'G.pth'))
if opt.gan:
os.rename(os.path.join(dir_checkpoint,'last_D.pth'),os.path.join(dir_checkpoint,str(iter+1)+'D.pth'))
print('network saved.')
#test
if (iter+1)%opt.savefreq == 0:
if os.path.isdir('./test'):
netG.eval()
test_names = os.listdir('./test')
test_names.sort()
result = np.zeros((opt.finesize*2,opt.finesize*len(test_names),3), dtype='uint8')
for cnt,test_name in enumerate(test_names,0):
img_names = os.listdir(os.path.join('./test',test_name,'image'))
img_names.sort()
inputdata = np.zeros((opt.finesize,opt.finesize,3*N+1), dtype='uint8')
for i in range(0,N):
img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
img = impro.resize(img,opt.finesize)
inputdata[:,:,i*3:(i+1)*3] = img
mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
mask = impro.resize(mask,opt.finesize)
mask = impro.mask_threshold(mask,15,128)
inputdata[:,:,-1] = mask
result[0:opt.finesize,opt.finesize*cnt:opt.finesize*(cnt+1),:] = inputdata[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
inputdata = data.im2tensor(inputdata,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
pred = netG(inputdata)
pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = False)
result[opt.finesize:opt.finesize*2,opt.finesize*cnt:opt.finesize*(cnt+1),:] = pred
cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.jpg'), result)
netG.train()
\ No newline at end of file
......@@ -10,11 +10,18 @@ transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
]
)
)
def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 = False):
def to_tensor(data,gpu_id):
data = torch.from_numpy(data)
if gpu_id != '-1':
data = data.cuda()
return data
def tensor2im(image_tensor, imtype=np.uint8, gray=False, rgb2bgr = True ,is0_1 = False, batch_index=0):
image_tensor =image_tensor.data
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = image_tensor[batch_index].cpu().float().numpy()
if not is0_1:
image_numpy = (image_numpy + 1)/2.0
......@@ -58,7 +65,7 @@ def im2tensor(image_numpy, imtype=np.uint8, gray=False,bgr2rgb = True, reshape =
image_tensor = torch.from_numpy(image_numpy).float()
if reshape:
image_tensor = image_tensor.reshape(1,ch,h,w)
if use_gpu != -1:
if use_gpu != '-1':
image_tensor = image_tensor.cuda()
return image_tensor
......@@ -68,53 +75,53 @@ def shuffledata(data,target):
np.random.set_state(state)
np.random.shuffle(target)
def random_transform_video(src,target,finesize,N):
#random blur
if random.random()<0.2:
h,w = src.shape[:2]
src = src[:8*(h//8),:8*(w//8)]
Q_ran = random.randint(1,15)
src[:,:,:3*N] = impro.dctblur(src[:,:,:3*N],Q_ran)
target = impro.dctblur(target,Q_ran)
# def random_transform_video(src,target,finesize,N):
# #random blur
# if random.random()<0.2:
# h,w = src.shape[:2]
# src = src[:8*(h//8),:8*(w//8)]
# Q_ran = random.randint(1,15)
# src[:,:,:3*N] = impro.dctblur(src[:,:,:3*N],Q_ran)
# target = impro.dctblur(target,Q_ran)
#random crop
h,w = target.shape[:2]
h_move = int((h-finesize)*random.random())
w_move = int((w-finesize)*random.random())
target = target[h_move:h_move+finesize,w_move:w_move+finesize,:]
src = src[h_move:h_move+finesize,w_move:w_move+finesize,:]
# #random crop
# h,w = target.shape[:2]
# h_move = int((h-finesize)*random.random())
# w_move = int((w-finesize)*random.random())
# target = target[h_move:h_move+finesize,w_move:w_move+finesize,:]
# src = src[h_move:h_move+finesize,w_move:w_move+finesize,:]
#random flip
if random.random()<0.5:
src = src[:,::-1,:]
target = target[:,::-1,:]
# #random flip
# if random.random()<0.5:
# src = src[:,::-1,:]
# target = target[:,::-1,:]
#random color
alpha = random.uniform(-0.1,0.1)
beta = random.uniform(-0.1,0.1)
b = random.uniform(-0.05,0.05)
g = random.uniform(-0.05,0.05)
r = random.uniform(-0.05,0.05)
for i in range(N):
src[:,:,i*3:(i+1)*3] = impro.color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r)
target = impro.color_adjust(target,alpha,beta,b,g,r)
#random resize blur
if random.random()<0.5:
interpolations = [cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_LANCZOS4]
size_ran = random.uniform(0.7,1.5)
interpolation_up = interpolations[random.randint(0,2)]
interpolation_down =interpolations[random.randint(0,2)]
# #random color
# alpha = random.uniform(-0.1,0.1)
# beta = random.uniform(-0.1,0.1)
# b = random.uniform(-0.05,0.05)
# g = random.uniform(-0.05,0.05)
# r = random.uniform(-0.05,0.05)
# for i in range(N):
# src[:,:,i*3:(i+1)*3] = impro.color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r)
# target = impro.color_adjust(target,alpha,beta,b,g,r)
# #random resize blur
# if random.random()<0.5:
# interpolations = [cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_LANCZOS4]
# size_ran = random.uniform(0.7,1.5)
# interpolation_up = interpolations[random.randint(0,2)]
# interpolation_down =interpolations[random.randint(0,2)]
tmp = cv2.resize(src[:,:,:3*N], (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up)
src[:,:,:3*N] = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down)
# tmp = cv2.resize(src[:,:,:3*N], (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up)
# src[:,:,:3*N] = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down)
tmp = cv2.resize(target, (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up)
target = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down)
# tmp = cv2.resize(target, (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up)
# target = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down)
return src,target
# return src,target
def random_transform_single(img,out_shape):
def random_transform_single_mask(img,out_shape):
out_h,out_w = out_shape
img = cv2.resize(img,(int(out_w*random.uniform(1.1, 1.5)),int(out_h*random.uniform(1.1, 1.5))))
h,w = img.shape[:2]
......@@ -130,7 +137,72 @@ def random_transform_single(img,out_shape):
img = cv2.resize(img,(out_w,out_h))
return img
def random_transform_image(img,mask,finesize,test_flag = False):
def get_transform_params():
scale_flag = np.random.random()<0.2
crop_flag = True
rotat_flag = np.random.random()<0.2
color_flag = True
flip_flag = np.random.random()<0.2
blur_flag = np.random.random()<0.1
flag_dict = {'scale':scale_flag,'crop':crop_flag,'rotat':rotat_flag,'color':color_flag,
'flip':flip_flag,'blur':blur_flag}
scale_rate = np.random.uniform(0.9,1.1)
crop_rate = [np.random.random(),np.random.random()]
rotat_rate = np.random.random()
color_rate = [np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05),
np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05)]
flip_rate = np.random.random()
blur_rate = np.random.randint(1,15)
rate_dict = {'scale':scale_rate,'crop':crop_rate,'rotat':rotat_rate,'color':color_rate,
'flip':flip_rate,'blur':blur_rate}
return {'flag':flag_dict,'rate':rate_dict}
def random_transform_single_image(img,finesize,params=None,test_flag = False):
if params is None:
params = get_transform_params()
if test_flag:
params['flag']['scale'] = False
if params['flag']['scale']:
h,w = img.shape[:2]
loadsize = min((h,w))
a = (float(h)/float(w))*params['rate']['scale']
if h<w:
img = cv2.resize(img, (int(loadsize/a),loadsize))
else:
img = cv2.resize(img, (loadsize,int(loadsize*a)))
if params['flag']['crop']:
h,w = img.shape[:2]
h_move = int((h-finesize)*params['rate']['crop'][0])
w_move = int((w-finesize)*params['rate']['crop'][1])
img = img[h_move:h_move+finesize,w_move:w_move+finesize]
if test_flag:
return img
if params['flag']['rotat']:
h,w = img.shape[:2]
M = cv2.getRotationMatrix2D((w/2,h/2),90*int(4*params['rate']['rotat']),1)
img = cv2.warpAffine(img,M,(w,h))
if params['flag']['color']:
img = impro.color_adjust(img,params['rate']['color'][0],params['rate']['color'][1],
params['rate']['color'][2],params['rate']['color'][3],params['rate']['color'][4])
if params['flag']['flip']:
img = img[:,::-1,:]
if params['flag']['blur']:
img = impro.dctblur(img,params['rate']['blur'])
#check shape
if img.shape[0]!= finesize or img.shape[1]!= finesize:
img = cv2.resize(img,(finesize,finesize))
print('warning! shape error.')
return img
def random_transform_pair_image(img,mask,finesize,test_flag = False):
#random scale
if random.random()<0.5:
h,w = img.shape[:2]
......@@ -191,28 +263,28 @@ def random_transform_image(img,mask,finesize,test_flag = False):
return img,mask
def load_train_video(videoname,img_index,opt):
N = opt.N
input_img = np.zeros((opt.loadsize,opt.loadsize,3*N+1), dtype='uint8')
# this frame
this_mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index)+'.png'),'gray',loadsize=opt.loadsize)
input_img[:,:,-1] = this_mask
#print(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'))
ground_true = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'),loadsize=opt.loadsize)
mosaic_size,mod,rect_rat,feather = mosaic.get_random_parameter(ground_true,this_mask)
start_pos = mosaic.get_random_startpos(num=N,bisa_p=0.3,bisa_max=mosaic_size,bisa_max_part=3)
# merge other frame
for i in range(0,N):
img = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index+i-int(N/2))+'.jpg'),loadsize=opt.loadsize)
mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index+i-int(N/2))+'.png'),'gray',loadsize=opt.loadsize)
img_mosaic = mosaic.addmosaic_base(img, mask, mosaic_size,model = mod,rect_rat=rect_rat,feather=feather,start_point=start_pos[i])
input_img[:,:,i*3:(i+1)*3] = img_mosaic
# to tensor
input_img,ground_true = random_transform_video(input_img,ground_true,opt.finesize,N)
input_img = im2tensor(input_img,bgr2rgb=False,use_gpu=-1,use_transform = False,is0_1=False)
ground_true = im2tensor(ground_true,bgr2rgb=False,use_gpu=-1,use_transform = False,is0_1=False)
# def load_train_video(videoname,img_index,opt):
# N = opt.N
# input_img = np.zeros((opt.loadsize,opt.loadsize,3*N+1), dtype='uint8')
# # this frame
# this_mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index)+'.png'),'gray',loadsize=opt.loadsize)
# input_img[:,:,-1] = this_mask
# #print(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'))
# ground_true = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'),loadsize=opt.loadsize)
# mosaic_size,mod,rect_rat,feather = mosaic.get_random_parameter(ground_true,this_mask)
# start_pos = mosaic.get_random_startpos(num=N,bisa_p=0.3,bisa_max=mosaic_size,bisa_max_part=3)
# # merge other frame
# for i in range(0,N):
# img = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index+i-int(N/2))+'.jpg'),loadsize=opt.loadsize)
# mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index+i-int(N/2))+'.png'),'gray',loadsize=opt.loadsize)
# img_mosaic = mosaic.addmosaic_base(img, mask, mosaic_size,model = mod,rect_rat=rect_rat,feather=feather,start_point=start_pos[i])
# input_img[:,:,i*3:(i+1)*3] = img_mosaic
# # to tensor
# input_img,ground_true = random_transform_video(input_img,ground_true,opt.finesize,N)
# input_img = im2tensor(input_img,bgr2rgb=False,use_gpu=-1,use_transform = False,is0_1=False)
# ground_true = im2tensor(ground_true,bgr2rgb=False,use_gpu=-1,use_transform = False,is0_1=False)
return input_img,ground_true
# return input_img,ground_true
def showresult(img1,img2,img3,name,is0_1 = False):
......
......@@ -17,7 +17,7 @@ DCT_Q = np.array([[8,16,19,22,26,27,29,34],
[26,27,29,34,38,46,56,59],
[27,29,35,38,46,56,69,83]])
def imread(file_path,mod = 'normal',loadsize = 0):
def imread(file_path,mod = 'normal',loadsize = 0, rgb=False):
'''
mod: 'normal' | 'gray' | 'all'
loadsize: 0->original
......@@ -43,6 +43,9 @@ def imread(file_path,mod = 'normal',loadsize = 0):
if loadsize != 0:
img = resize(img, loadsize, interpolation=cv2.INTER_CUBIC)
if rgb and img.ndim==3:
img = img[:,:,::-1]
return img
def imwrite(file_path,img):
......@@ -252,4 +255,27 @@ def replace_mosaic(img_origin,img_fake,mask,x,y,size,no_feather):
img_result = img_origin.copy()
img_result = (img_origin*(1-mask)+img_tmp*mask).astype('uint8')
return img_result
\ No newline at end of file
return img_result
def psnr(img1,img2):
mse = np.mean((img1/255.0-img2/255.0)**2)
if mse < 1e-10:
return 100
psnr_v = 20*np.log10(1/np.sqrt(mse))
return psnr_v
def splice(imgs,splice_shape):
'''Stitching multiple images, all imgs must have the same size
imgs : [img1,img2,img3,img4]
splice_shape: (2,2)
'''
h,w,ch = imgs[0].shape
output = np.zeros((h*splice_shape[0],w*splice_shape[1],ch),np.uint8)
cnt = 0
for i in range(splice_shape[0]):
for j in range(splice_shape[1]):
if cnt < len(imgs):
output[h*i:h*(i+1),w*j:w*(j+1)] = imgs[cnt]
cnt += 1
return output
import os
import random
import string
import shutil
def Traversal(filedir):
......@@ -10,6 +12,9 @@ def Traversal(filedir):
Traversal(dir)
return file_list
def randomstr(num):
return ''.join(random.sample(string.ascii_letters + string.digits, num))
def is_img(path):
ext = os.path.splitext(path)[1]
ext = ext.lower()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册