提交 48c032b1 编写于 作者: H hypox64

Ready to add GAN

上级 4c6b29b4
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pix2pixHD_model import *
from .model_util import *
from models import model_util
class UpBlock(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
......@@ -17,14 +17,10 @@ class UpBlock(nn.Module):
# Blur(out_channel),
)
def forward(self, input):
outup = self.convup(input)
return outup
class Encoder2d(nn.Module):
def __init__(self, input_nc, ngf=64, n_downsampling=3, activation = nn.LeakyReLU(0.2)):
super(Encoder2d, self).__init__()
......@@ -52,21 +48,19 @@ class Encoder3d(nn.Module):
mult = 2**i
model += [ SpectralNorm(nn.Conv3d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
activation]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
class BVDNet(nn.Module):
def __init__(self, N, n_downsampling=3, n_blocks=1, input_nc=3, output_nc=3,activation=nn.LeakyReLU(0.2)):
def __init__(self, N=2, n_downsampling=3, n_blocks=4, input_nc=3, output_nc=3,activation=nn.LeakyReLU(0.2)):
super(BVDNet, self).__init__()
ngf = 64
padding_type = 'reflect'
self.N = N
# encoder
### encoder
self.encoder3d = Encoder3d(input_nc,64,n_downsampling,activation)
self.encoder2d = Encoder2d(input_nc,64,n_downsampling,activation)
......@@ -82,13 +76,6 @@ class BVDNet(nn.Module):
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
self.decoder += [UpBlock(ngf * mult, int(ngf * mult / 2))]
# self.decoder += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
# norm_layer(int(ngf * mult / 2)), activation]
# self.decoder += [ nn.Upsample(scale_factor = 2, mode='nearest'),
# nn.ReflectionPad2d(1),
# nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0),
# norm_layer(int(ngf * mult / 2)),
# activation]
self.decoder += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
self.decoder = nn.Sequential(*self.decoder)
self.limiter = nn.Tanh()
......@@ -97,68 +84,124 @@ class BVDNet(nn.Module):
this_shortcut = stream[:,:,self.N]
stream = self.encoder3d(stream)
stream = stream.reshape(stream.size(0),stream.size(1),stream.size(3),stream.size(4))
# print(stream.shape)
previous = self.encoder2d(previous)
x = stream + previous
x = self.blocks(x)
x = self.decoder(x)
x = x+this_shortcut
x = self.limiter(x)
#print(x.shape)
# print(stream.shape,previous.shape)
return x
class VGGLoss(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss, self).__init__()
self.vgg = Vgg19()
if gpu_ids != '-1' and len(gpu_ids) == 1:
self.vgg.cuda()
elif gpu_ids != '-1' and len(gpu_ids) > 1:
self.vgg = nn.DataParallel(self.vgg)
self.vgg.cuda()
self.criterion = nn.L1Loss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
from torchvision import models
class Vgg19(torch.nn.Module):
    """Pretrained VGG19 split into five stages; forward returns all five activations.

    Stage boundaries are indices into torchvision's vgg19().features:
    [0,2), [2,7), [7,12), [12,21), [21,30).
    """

    # (start, end) index pairs for slice1 .. slice5.
    _STAGES = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # NOTE: downloads ImageNet weights on first use.
        features = models.vgg19(pretrained=True).features
        for idx, (start, end) in enumerate(self._STAGES, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, end):
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, 'slice%d' % idx, stage)
        if not requires_grad:
            # Frozen feature extractor (the common case for perceptual losses).
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, X):
        """Run X through the five stages, collecting each stage's output."""
        outputs = []
        h = X
        for idx in range(1, 6):
            h = getattr(self, 'slice%d' % idx)(h)
            outputs.append(h)
        return outputs
def define_G(N=2, n_blocks=1, gpu_id='-1'):
    """Build the BVDNet generator and place it on GPU(s) according to gpu_id.

    gpu_id is a string: '-1' = CPU, one character = single GPU,
    longer = DataParallel over multiple GPUs.
    """
    netG = BVDNet(N=N, n_blocks=n_blocks)
    on_gpu = gpu_id != '-1'
    if on_gpu and len(gpu_id) > 1:
        netG = nn.DataParallel(netG)
    if on_gpu and len(gpu_id) >= 1:
        netG.cuda()
    # Weight init intentionally not applied here (unlike define_D):
    # netG.apply(model_util.init_weights)
    return netG
################################Discriminator################################
def define_D(input_nc, ndf, n_layers_D, use_sigmoid=False, num_D=1, gpu_id='-1'):
    """Build a multiscale PatchGAN discriminator, move it to GPU(s), and init weights.

    gpu_id is a string: '-1' = CPU, one character = single GPU,
    longer = DataParallel over multiple GPUs.
    """
    netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, num_D)
    on_gpu = gpu_id != '-1'
    if on_gpu and len(gpu_id) > 1:
        netD = nn.DataParallel(netD)
    if on_gpu and len(gpu_id) >= 1:
        netD.cuda()
    netD.apply(model_util.init_weights)
    return netD
class MultiscaleDiscriminator(nn.Module):
    """Runs num_D PatchGAN discriminators on progressively downsampled input.

    forward returns a list of length num_D; each entry is the single-element
    list produced by singleD_forward (nesting kept for the loss wrapper).
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, num_D=3):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        # One discriminator per scale, registered as layer0 .. layer{num_D-1}.
        for scale in range(num_D):
            sub = NLayerDiscriminator(input_nc, ndf, n_layers, use_sigmoid)
            setattr(self, 'layer' + str(scale), sub.model)
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        # Wrapped in a list to keep a [scale][output] nesting at the call site.
        return [model(input)]

    def forward(self, input):
        result = []
        x = input
        for scale in range(self.num_D):
            # Highest-numbered layer sees the full-resolution input first.
            sub = getattr(self, 'layer' + str(self.num_D - 1 - scale))
            result.append(self.singleD_forward(sub, x))
            if scale < self.num_D - 1:
                x = self.downsample(x)
        return result
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: stacked stride-2 4x4 convs with LeakyReLU,
    spectral norm on the hidden layers, ending in a 1-channel logit map."""

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        self.n_layers = n_layers
        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))  # = 2 for a 4x4 kernel
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
                  nn.LeakyReLU(0.2)]
        channels = ndf
        # n_layers-1 additional stride-2 stages; channel count doubles, capped at 512.
        for _ in range(1, n_layers):
            prev = channels
            channels = min(channels * 2, 512)
            layers += [SpectralNorm(nn.Conv2d(prev, channels, kernel_size=kernel, stride=2, padding=pad)),
                       nn.LeakyReLU(0.2)]
        # One stride-1 stage before the final 1-channel projection.
        prev = channels
        channels = min(channels * 2, 512)
        layers += [SpectralNorm(nn.Conv2d(prev, channels, kernel_size=kernel, stride=1, padding=pad)),
                   nn.LeakyReLU(0.2)]
        layers.append(nn.Conv2d(channels, 1, kernel_size=kernel, stride=1, padding=pad))
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the patch-wise real/fake score map."""
        return self.model(input)
class GANLoss(nn.Module):
    """Hinge GAN loss wrapper.

    mode 'D' scores (fake, real) pairs; mode 'G' scores fakes only.
    Accepts either a single discriminator output or a list of per-scale
    outputs; list entries are weighted 2**(len-1), ..., 2, 1 in order.
    """

    def __init__(self, mode='D'):
        super(GANLoss, self).__init__()
        if mode == 'D':
            self.lossf = model_util.HingeLossD()
        elif mode == 'G':
            self.lossf = model_util.HingeLossG()
        self.mode = mode

    def forward(self, dis_fake=None, dis_real=None):
        # Single-output case: apply the hinge loss directly.
        if not isinstance(dis_fake, list):
            if self.mode == 'D':
                return self.lossf(dis_fake, dis_real)
            elif self.mode == 'G':
                return self.lossf(dis_fake)
            return None
        # Multiscale case: halve the weight at each successive scale.
        total = 0
        weight = 2 ** len(dis_fake)
        for idx in range(len(dis_fake)):
            weight = weight / 2
            if self.mode == 'D':
                total += weight * self.lossf(dis_fake[idx], dis_real[idx])
            elif self.mode == 'G':
                total += weight * self.lossf(dis_fake[idx])
        return total
import functools
from math import exp
import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
import functools
from torchvision import models
def save(net,path,gpu_id):
if isinstance(net, nn.DataParallel):
......@@ -13,6 +17,7 @@ def save(net,path,gpu_id):
if gpu_id != '-1':
net.cuda()
################################## initialization ##################################
def get_norm_layer(norm_type='instance',mod = '2d'):
if norm_type == 'batch':
if mod == '2d':
......@@ -51,9 +56,10 @@ def init_weights(net, init_type='normal', gain=0.02):
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
# print('initialize network with %s' % init_type)
net.apply(init_func)
################################## Network structure ##################################
class ResnetBlockSpectralNorm(nn.Module):
def __init__(self, dim, padding_type, activation=nn.LeakyReLU(0.2), use_dropout=False):
super(ResnetBlockSpectralNorm, self).__init__()
......@@ -91,4 +97,160 @@ class ResnetBlockSpectralNorm(nn.Module):
def forward(self, x):
out = x + self.conv_block(x)
return out
\ No newline at end of file
return out
################################## Loss function ##################################
class HingeLossD(nn.Module):
    """Hinge loss for the discriminator: E[relu(1 - D(real))] + E[relu(1 + D(fake))]."""

    def __init__(self):
        super(HingeLossD, self).__init__()

    def forward(self, dis_fake, dis_real):
        # clamp(min=0) is relu; real scores are pushed above +1, fake below -1.
        real_term = (1.0 - dis_real).clamp(min=0).mean()
        fake_term = (1.0 + dis_fake).clamp(min=0).mean()
        return real_term + fake_term
class HingeLossG(nn.Module):
    """Hinge (non-saturating) generator loss: -E[D(fake)]."""

    def __init__(self):
        super(HingeLossG, self).__init__()

    def forward(self, dis_fake):
        # Maximize the discriminator's score on generated samples.
        return dis_fake.mean().neg()
class VGGLoss(nn.Module):
    # Perceptual loss: weighted L1 distance between VGG19 feature maps of x and y.
    def __init__(self, gpu_id):
        """gpu_id is a string: '-1' = CPU, one char = single GPU, longer = DataParallel."""
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19()
        if gpu_id != '-1' and len(gpu_id) == 1:
            self.vgg.cuda()
        elif gpu_id != '-1' and len(gpu_id) > 1:
            self.vgg = nn.DataParallel(self.vgg)
            self.vgg.cuda()
        self.criterion = nn.L1Loss()
        # Deeper VGG stages get progressively larger weights.
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        """Return sum_i w_i * L1(vgg_i(x), vgg_i(y)).

        y's features are detached so no gradient flows into the target image.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
class Vgg19(torch.nn.Module):
    """Pretrained VGG19 split into five stages; forward returns all five activations.

    Slice boundaries (indices into torchvision's vgg19().features):
    [0,2), [2,7), [7,12), [12,21), [21,30) -- presumably the usual
    perceptual-loss split; confirm against torchvision's layer indexing.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # NOTE: downloads ImageNet weights on first use.
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            # Frozen feature extractor (the common case for perceptual losses).
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run X through the five stages and collect each stage's output."""
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
################################## Evaluation ##################################
'''https://github.com/Po-Hsun-Su/pytorch-ssim
img1 = Variable(torch.rand(1, 1, 256, 256))
img2 = Variable(torch.rand(1, 1, 256, 256))
if torch.cuda.is_available():
img1 = img1.cuda()
img2 = img2.cuda()
print(pytorch_ssim.ssim(img1, img2))
ssim_loss = pytorch_ssim.SSIM(window_size = 11)
print(ssim_loss(img1, img2))
'''
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length window_size, normalized to sum to 1."""
    center = window_size // 2
    weights = [exp(-((i - center) ** 2) / (2.0 * sigma ** 2)) for i in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()
def create_window(window_size, channel):
    """Build a (channel, 1, window_size, window_size) Gaussian filter for depthwise conv2d."""
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product -> 2-D Gaussian; then prepend batch/channel dims.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # One copy per channel so F.conv2d(..., groups=channel) filters channels independently.
    return Variable(kernel_2d.expand(channel, 1, window_size, window_size).contiguous())
def _ssim(img1, img2, window, window_size, channel, size_average = True):
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
    """Differentiable SSIM module; caches the Gaussian window between calls."""

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Window is rebuilt lazily for the real channel count/dtype on first forward.
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return SSIM(img1, img2); scalar if size_average else one value per image."""
        (_, channel, _, _) = img1.size()
        # Reuse the cached window only while channel count and dtype still match.
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            # Cache for subsequent calls with the same channel count/dtype.
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM: builds a fresh Gaussian window matching img1's
    channels/device/dtype and delegates to _ssim."""
    channel = img1.size(1)
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)
......@@ -11,11 +11,10 @@ import random
import torch
import torch.nn as nn
import time
from multiprocessing import Process, Queue
from util import mosaic,util,ffmpeg,filt,data
from util import mosaic,util,ffmpeg,filt,data,dataloader
from util import image_processing as impro
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model,BVDNet,model_util
from models import BVDNet,model_util
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter
......@@ -26,13 +25,15 @@ opt.parser.add_argument('--N',type=int,default=2, help='The input tensor shape i
opt.parser.add_argument('--S',type=int,default=3, help='Stride of 3 frames')
# opt.parser.add_argument('--T',type=int,default=7, help='T = 2N+1')
opt.parser.add_argument('--M',type=int,default=100, help='How many frames read from each videos')
opt.parser.add_argument('--lr',type=float,default=0.001, help='')
opt.parser.add_argument('--lr',type=float,default=0.0002, help='')
opt.parser.add_argument('--beta1',type=float,default=0.9, help='')
opt.parser.add_argument('--beta2',type=float,default=0.999, help='')
opt.parser.add_argument('--finesize',type=int,default=256, help='')
opt.parser.add_argument('--loadsize',type=int,default=286, help='')
opt.parser.add_argument('--batchsize',type=int,default=1, help='')
opt.parser.add_argument('--lambda_VGG',type=float,default=0.1, help='')
opt.parser.add_argument('--lambda_L2',type=float,default=100, help='')
opt.parser.add_argument('--lambda_VGG',type=float,default=1, help='')
opt.parser.add_argument('--lambda_GAN',type=float,default=1, help='')
opt.parser.add_argument('--load_thread',type=int,default=4, help='number of thread for loading data')
opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
......@@ -45,134 +46,6 @@ opt.parser.add_argument('--showresult_freq',type=int,default=1000, help='')
opt.parser.add_argument('--showresult_num',type=int,default=4, help='')
opt.parser.add_argument('--psnr_freq',type=int,default=100, help='')
class TrainVideoLoader(object):
    """Load a single video (converted to numbered frame images) for training.

    How to use:
        1. Init TrainVideoLoader as loader
        2. Get data by loader.ori_stream / loader.mosaic_stream
        3. loader.next() to slide the window one frame forward
    """

    def __init__(self, opt, video_dir, test_flag=False):
        # assumes video_dir contains 'origin_image/%05d.jpg' and 'mask/%05d.png'
        # with frames numbered from 00001 -- TODO confirm against dataset layout
        super(TrainVideoLoader, self).__init__()
        self.opt = opt
        self.test_flag = test_flag
        self.video_dir = video_dir
        self.t = 0  # how many times next() has been called
        # Number of sliding-window steps available in an M-frame video.
        self.n_iter = self.opt.M - self.opt.S * (self.opt.T + 1)
        # One shared set of augmentation params so all frames transform alike.
        self.transform_params = data.get_transform_params()
        self.ori_load_pool = []
        self.mosaic_load_pool = []
        self.previous_pred = None
        # Sample mosaic parameters once per video from the first frame/mask.
        feg_ori = impro.imread(os.path.join(video_dir, 'origin_image', '00001.jpg'), loadsize=self.opt.loadsize, rgb=True)
        feg_mask = impro.imread(os.path.join(video_dir, 'mask', '00001.png'), mod='gray', loadsize=self.opt.loadsize)
        self.mosaic_size, self.mod, self.rect_rat, self.feather = mosaic.get_random_parameter(feg_ori, feg_mask)
        self.startpos = [random.randint(0, self.mosaic_size), random.randint(0, self.mosaic_size)]
        # Init load pool: the first S*T frames, mosaicked and augmented.
        for i in range(self.opt.S * self.opt.T):
            _ori_img = impro.imread(os.path.join(video_dir, 'origin_image', '%05d' % (i + 1) + '.jpg'), loadsize=self.opt.loadsize, rgb=True)
            _mask = impro.imread(os.path.join(video_dir, 'mask', '%05d' % (i + 1) + '.png'), mod='gray', loadsize=self.opt.loadsize)
            _mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size, 0, self.mod, self.rect_rat, self.feather, self.startpos)
            _ori_img = data.random_transform_single_image(_ori_img, opt.finesize, self.transform_params)
            _mosaic_img = data.random_transform_single_image(_mosaic_img, opt.finesize, self.transform_params)
            self.ori_load_pool.append(self.normalize(_ori_img))
            self.mosaic_load_pool.append(self.normalize(_mosaic_img))
        self.ori_load_pool = np.array(self.ori_load_pool)
        self.mosaic_load_pool = np.array(self.mosaic_load_pool)
        # Init first stream: T frames sampled every S frames from the pool.
        self.ori_stream = self.ori_load_pool[np.linspace(0, (self.opt.T - 1) * self.opt.S, self.opt.T, dtype=np.int64)].copy()
        self.mosaic_stream = self.mosaic_load_pool[np.linspace(0, (self.opt.T - 1) * self.opt.S, self.opt.T, dtype=np.int64)].copy()
        # stream B,T,H,W,C -> B,C,T,H,W
        self.ori_stream = self.ori_stream.reshape(1, self.opt.T, opt.finesize, opt.finesize, 3).transpose((0, 4, 1, 2, 3))
        self.mosaic_stream = self.mosaic_stream.reshape(1, self.opt.T, opt.finesize, opt.finesize, 3).transpose((0, 4, 1, 2, 3))
        # Init first previous frame: ground-truth frame just before the center one.
        self.previous_pred = self.ori_load_pool[self.opt.S * self.opt.N - 1].copy()
        # previous B,C,H,W
        self.previous_pred = self.previous_pred.reshape(1, opt.finesize, opt.finesize, 3).transpose((0, 3, 1, 2))

    def normalize(self, data):
        '''
        Map uint8 [0, 255] to float32 [-1, 1].
        '''
        return (data.astype(np.float32) / 255.0 - 0.5) / 0.5

    def anti_normalize(self, data):
        # Inverse of normalize: map [-1, 1] back to uint8 [0, 255].
        return np.clip((data * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)

    def next(self):
        # NOTE(review): reconstructed indentation -- the shift/load body is
        # guarded by t != 0 (the t==0 stream comes from __init__, and frame
        # index S*T+t only names a new frame once t >= 1); confirm the caller's
        # read-then-advance ordering expects the first call to be a no-op.
        if self.t != 0:
            # From the second step on, previous frame comes from the network's
            # own prediction (handled by the caller), so drop the GT one.
            self.previous_pred = None
            # Shift the pools left by one frame and load the next frame at the end.
            self.ori_load_pool[:self.opt.S * self.opt.T - 1] = self.ori_load_pool[1:self.opt.S * self.opt.T]
            self.mosaic_load_pool[:self.opt.S * self.opt.T - 1] = self.mosaic_load_pool[1:self.opt.S * self.opt.T]
            _ori_img = impro.imread(os.path.join(self.video_dir, 'origin_image', '%05d' % (self.opt.S * self.opt.T + self.t) + '.jpg'), loadsize=self.opt.loadsize, rgb=True)
            _mask = impro.imread(os.path.join(self.video_dir, 'mask', '%05d' % (self.opt.S * self.opt.T + self.t) + '.png'), mod='gray', loadsize=self.opt.loadsize)
            _mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size, 0, self.mod, self.rect_rat, self.feather, self.startpos)
            _ori_img = data.random_transform_single_image(_ori_img, opt.finesize, self.transform_params)
            _mosaic_img = data.random_transform_single_image(_mosaic_img, opt.finesize, self.transform_params)
            _ori_img, _mosaic_img = self.normalize(_ori_img), self.normalize(_mosaic_img)
            self.ori_load_pool[self.opt.S * self.opt.T - 1] = _ori_img
            self.mosaic_load_pool[self.opt.S * self.opt.T - 1] = _mosaic_img
            # Rebuild the T-frame streams from the refreshed pools.
            self.ori_stream = self.ori_load_pool[np.linspace(0, (self.opt.T - 1) * self.opt.S, self.opt.T, dtype=np.int64)].copy()
            self.mosaic_stream = self.mosaic_load_pool[np.linspace(0, (self.opt.T - 1) * self.opt.S, self.opt.T, dtype=np.int64)].copy()
            # stream B,T,H,W,C -> B,C,T,H,W
            self.ori_stream = self.ori_stream.reshape(1, self.opt.T, opt.finesize, opt.finesize, 3).transpose((0, 4, 1, 2, 3))
            self.mosaic_stream = self.mosaic_stream.reshape(1, self.opt.T, opt.finesize, opt.finesize, 3).transpose((0, 4, 1, 2, 3))
        self.t += 1
class DataLoader(object):
    """Multi-process data loader built on TrainVideoLoader.

    Spawns opt.load_thread worker processes; each walks its share of the
    shuffled video list and pushes ready batches into a shared queue.
    """

    def __init__(self, opt, videolist, test_flag=False):
        super(DataLoader, self).__init__()
        self.videolist = []
        self.opt = opt
        self.test_flag = test_flag
        # Repeat the video list once per epoch so workers cover the whole schedule.
        for i in range(self.opt.n_epoch):
            self.videolist += videolist
        random.shuffle(self.videolist)
        self.each_video_n_iter = self.opt.M - self.opt.S * (self.opt.T + 1)
        # Total iterations the consumer may call get_data() for.
        self.n_iter = len(self.videolist) // self.opt.load_thread // self.opt.batchsize * self.each_video_n_iter * self.opt.load_thread
        self.queue = Queue(self.opt.load_thread)  # bounded: workers block when full
        # Reusable batch buffers. B,C,T,H,W
        self.ori_stream = np.zeros((self.opt.batchsize, 3, self.opt.T, self.opt.finesize, self.opt.finesize), dtype=np.float32)
        self.mosaic_stream = np.zeros((self.opt.batchsize, 3, self.opt.T, self.opt.finesize, self.opt.finesize), dtype=np.float32)
        # B,C,H,W ground-truth previous frame for the first clip of each video.
        self.previous_pred = np.zeros((self.opt.batchsize, 3, self.opt.finesize, self.opt.finesize), dtype=np.float32)
        self.load_init()

    def load(self, videolist):
        """Worker loop: group batchsize videos and stream their clips into the queue."""
        for load_video_iter in range(len(videolist) // self.opt.batchsize):
            iter_videolist = videolist[load_video_iter * self.opt.batchsize:(load_video_iter + 1) * self.opt.batchsize]
            videoloaders = [TrainVideoLoader(self.opt, os.path.join(self.opt.dataset, iter_videolist[i]), self.test_flag) for i in range(self.opt.batchsize)]
            for each_video_iter in range(self.each_video_n_iter):
                for i in range(self.opt.batchsize):
                    self.ori_stream[i] = videoloaders[i].ori_stream
                    self.mosaic_stream[i] = videoloaders[i].mosaic_stream
                    # A ground-truth previous frame only exists for the first clip.
                    if each_video_iter == 0:
                        self.previous_pred[i] = videoloaders[i].previous_pred
                    videoloaders[i].next()
                if each_video_iter == 0:
                    self.queue.put([self.ori_stream.copy(), self.mosaic_stream.copy(), self.previous_pred])
                else:
                    # None tells the consumer to reuse its own last prediction.
                    self.queue.put([self.ori_stream.copy(), self.mosaic_stream.copy(), None])

    def load_init(self):
        """Start opt.load_thread daemon worker processes, each on its own video slice."""
        ptvn = len(self.videolist) // self.opt.load_thread  # per-thread video count
        for i in range(self.opt.load_thread):
            p = Process(target=self.load, args=(self.videolist[i * ptvn:(i + 1) * ptvn],))
            p.daemon = True
            p.start()

    def get_data(self):
        """Block until a batch [ori_stream, mosaic_stream, previous_pred-or-None] is ready."""
        return self.queue.get()
'''
--------------------------Init--------------------------
'''
......@@ -186,21 +59,21 @@ util.makedirs(dir_checkpoint)
localtime = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
tensorboard_savedir = os.path.join('checkpoints/tensorboard',localtime+'_'+opt.savename)
TBGlobalWriter = SummaryWriter(tensorboard_savedir)
net = BVDNet.BVDNet(opt.N)
'''
--------------------------Init Network--------------------------
'''
if opt.gpu_id != '-1' and len(opt.gpu_id) == 1:
torch.backends.cudnn.benchmark = True
net.cuda()
elif opt.gpu_id != '-1' and len(opt.gpu_id) > 1:
torch.backends.cudnn.benchmark = True
net = nn.DataParallel(net)
net.cuda()
netG = BVDNet.define_G(opt.N,gpu_id=opt.gpu_id)
optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
lossf_L1 = nn.L1Loss()
lossf_VGG = BVDNet.VGGLoss([opt.gpu_id])
optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
lossf_L2 = nn.MSELoss()
lossf_VGG = model_util.VGGLoss(opt.gpu_id)
'''
--------------------------Init DataLoader--------------------------
'''
videolist_tmp = os.listdir(opt.dataset)
videolist = []
for video in videolist_tmp:
......@@ -211,33 +84,36 @@ sorted(videolist)
videolist_train = videolist[:int(len(videolist)*0.8)].copy()
videolist_eval = videolist[int(len(videolist)*0.8):].copy()
dataloader_train = DataLoader(opt, videolist_train)
dataloader_eval = DataLoader(opt, videolist_eval)
Videodataloader_train = dataloader.VideoDataLoader(opt, videolist_train)
Videodataloader_eval = dataloader.VideoDataLoader(opt, videolist_eval)
'''
--------------------------Train--------------------------
'''
previous_predframe_tmp = 0
for train_iter in range(dataloader_train.n_iter):
for train_iter in range(Videodataloader_train.n_iter):
t_start = time.time()
# train
ori_stream,mosaic_stream,previous_frame = dataloader_train.get_data()
ori_stream,mosaic_stream,previous_frame = Videodataloader_train.get_data()
ori_stream = data.to_tensor(ori_stream, opt.gpu_id)
mosaic_stream = data.to_tensor(mosaic_stream, opt.gpu_id)
if previous_frame is None:
previous_frame = data.to_tensor(previous_predframe_tmp, opt.gpu_id)
else:
previous_frame = data.to_tensor(previous_frame, opt.gpu_id)
optimizer.zero_grad()
out = net(mosaic_stream,previous_frame)
loss_L1 = lossf_L1(out,ori_stream[:,:,opt.N])
optimizer_G.zero_grad()
out = netG(mosaic_stream,previous_frame)
loss_L2 = lossf_L2(out,ori_stream[:,:,opt.N]) * opt.lambda_L2
loss_VGG = lossf_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
TBGlobalWriter.add_scalars('loss/train', {'L1':loss_L1.item(),'VGG':loss_VGG.item()}, train_iter)
loss = loss_L1+loss_VGG
TBGlobalWriter.add_scalars('loss/train', {'L2':loss_L2.item(),'VGG':loss_VGG.item()}, train_iter)
loss = loss_L2+loss_VGG
loss.backward()
optimizer.step()
optimizer_G.step()
previous_predframe_tmp = out.detach().cpu().numpy()
# save network
if train_iter%opt.save_freq == 0 and train_iter != 0:
model_util.save(net, os.path.join('checkpoints',opt.savename,str(train_iter)+'.pth'), opt.gpu_id)
model_util.save(netG, os.path.join('checkpoints',opt.savename,str(train_iter)+'.pth'), opt.gpu_id)
# psnr
if train_iter%opt.psnr_freq ==0:
......@@ -254,10 +130,12 @@ for train_iter in range(dataloader_train.n_iter):
data.tensor2im(ori_stream[:,:,opt.N],rgb2bgr = False,batch_index=i)]
show_img = impro.splice(show_imgs, (opt.showresult_num,3))
TBGlobalWriter.add_image('train', show_img,train_iter,dataformats='HWC')
# eval
'''
--------------------------Eval--------------------------
'''
if (train_iter)%5 ==0:
ori_stream,mosaic_stream,previous_frame = dataloader_eval.get_data()
ori_stream,mosaic_stream,previous_frame = Videodataloader_eval.get_data()
ori_stream = data.to_tensor(ori_stream, opt.gpu_id)
mosaic_stream = data.to_tensor(mosaic_stream, opt.gpu_id)
if previous_frame is None:
......@@ -265,10 +143,10 @@ for train_iter in range(dataloader_train.n_iter):
else:
previous_frame = data.to_tensor(previous_frame, opt.gpu_id)
with torch.no_grad():
out = net(mosaic_stream,previous_frame)
loss_L1 = lossf_L1(out,ori_stream[:,:,opt.N])
out = netG(mosaic_stream,previous_frame)
loss_L2 = lossf_L2(out,ori_stream[:,:,opt.N])
loss_VGG = lossf_VGG(out,ori_stream[:,:,opt.N]) * opt.lambda_VGG
TBGlobalWriter.add_scalars('loss/eval', {'L1':loss_L1.item(),'VGG':loss_VGG.item()}, train_iter)
TBGlobalWriter.add_scalars('loss/eval', {'L2':loss_L2.item(),'VGG':loss_VGG.item()}, train_iter)
previous_predframe_tmp = out.detach().cpu().numpy()
#psnr
......@@ -277,7 +155,7 @@ for train_iter in range(dataloader_train.n_iter):
for i in range(len(out)):
psnr += impro.psnr(data.tensor2im(out,batch_index=i), data.tensor2im(ori_stream[:,:,opt.N],batch_index=i))
TBGlobalWriter.add_scalars('psnr', {'eval':psnr/len(out)}, train_iter)
#show
if train_iter % opt.showresult_freq == 0:
show_imgs = []
for i in range(opt.showresult_num):
......@@ -287,11 +165,13 @@ for train_iter in range(dataloader_train.n_iter):
show_img = impro.splice(show_imgs, (opt.showresult_num,3))
TBGlobalWriter.add_image('eval', show_img,train_iter,dataformats='HWC')
t_end = time.time()
print('iter:{0:d} t:{1:.2f} l1:{2:.4f} vgg:{3:.4f} psnr:{4:.2f}'.format(train_iter,t_end-t_start,
loss_L1.item(),loss_VGG.item(),psnr/len(out)) )
print('iter:{0:d} t:{1:.2f} L2:{2:.4f} vgg:{3:.4f} psnr:{4:.2f}'.format(train_iter,t_end-t_start,
loss_L2.item(),loss_VGG.item(),psnr/len(out)) )
t_strat = time.time()
# test
'''
--------------------------Test--------------------------
'''
if train_iter % opt.showresult_freq == 0 and os.path.isdir(opt.dataset_test):
show_imgs = []
videos = os.listdir(opt.dataset_test)
......@@ -309,7 +189,7 @@ for train_iter in range(dataloader_train.n_iter):
mosaic_stream = data.to_tensor(mosaic_stream, opt.gpu_id)
previous = data.im2tensor(previous,bgr2rgb = False, gpu_id = opt.gpu_id,use_transform = False, is0_1 = False)
with torch.no_grad():
out = net(mosaic_stream,previous)
out = netG(mosaic_stream,previous)
show_imgs+= [data.tensor2im(mosaic_stream[:,:,opt.N],rgb2bgr = False),data.tensor2im(out,rgb2bgr = False)]
show_img = impro.splice(show_imgs, (len(videos),2))
......
......@@ -5,7 +5,7 @@ import torch
import torchvision.transforms as transforms
import cv2
from . import image_processing as impro
from . import mosaic
from . import degradater
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
......@@ -75,51 +75,6 @@ def shuffledata(data,target):
np.random.set_state(state)
np.random.shuffle(target)
# def random_transform_video(src,target,finesize,N):
# #random blur
# if random.random()<0.2:
# h,w = src.shape[:2]
# src = src[:8*(h//8),:8*(w//8)]
# Q_ran = random.randint(1,15)
# src[:,:,:3*N] = impro.dctblur(src[:,:,:3*N],Q_ran)
# target = impro.dctblur(target,Q_ran)
# #random crop
# h,w = target.shape[:2]
# h_move = int((h-finesize)*random.random())
# w_move = int((w-finesize)*random.random())
# target = target[h_move:h_move+finesize,w_move:w_move+finesize,:]
# src = src[h_move:h_move+finesize,w_move:w_move+finesize,:]
# #random flip
# if random.random()<0.5:
# src = src[:,::-1,:]
# target = target[:,::-1,:]
# #random color
# alpha = random.uniform(-0.1,0.1)
# beta = random.uniform(-0.1,0.1)
# b = random.uniform(-0.05,0.05)
# g = random.uniform(-0.05,0.05)
# r = random.uniform(-0.05,0.05)
# for i in range(N):
# src[:,:,i*3:(i+1)*3] = impro.color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r)
# target = impro.color_adjust(target,alpha,beta,b,g,r)
# #random resize blur
# if random.random()<0.5:
# interpolations = [cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_LANCZOS4]
# size_ran = random.uniform(0.7,1.5)
# interpolation_up = interpolations[random.randint(0,2)]
# interpolation_down =interpolations[random.randint(0,2)]
# tmp = cv2.resize(src[:,:,:3*N], (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up)
# src[:,:,:3*N] = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down)
# tmp = cv2.resize(target, (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolation_up)
# target = cv2.resize(tmp, (finesize,finesize),interpolation=interpolation_down)
# return src,target
def random_transform_single_mask(img,out_shape):
out_h,out_w = out_shape
......@@ -138,40 +93,27 @@ def random_transform_single_mask(img,out_shape):
return img
def get_transform_params():
scale_flag = np.random.random()<0.2
crop_flag = True
rotat_flag = np.random.random()<0.2
color_flag = True
flip_flag = np.random.random()<0.2
blur_flag = np.random.random()<0.1
flag_dict = {'scale':scale_flag,'crop':crop_flag,'rotat':rotat_flag,'color':color_flag,
'flip':flip_flag,'blur':blur_flag}
degradate_flag = np.random.random()<0.5
flag_dict = {'crop':crop_flag,'rotat':rotat_flag,'color':color_flag,'flip':flip_flag,'degradate':degradate_flag}
scale_rate = np.random.uniform(0.9,1.1)
crop_rate = [np.random.random(),np.random.random()]
rotat_rate = np.random.random()
color_rate = [np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05),
np.random.uniform(-0.05,0.05),np.random.uniform(-0.05,0.05)]
flip_rate = np.random.random()
blur_rate = np.random.randint(1,15)
rate_dict = {'scale':scale_rate,'crop':crop_rate,'rotat':rotat_rate,'color':color_rate,
'flip':flip_rate,'blur':blur_rate}
degradate_params = degradater.get_random_degenerate_params(mod='weaker_1')
rate_dict = {'crop':crop_rate,'rotat':rotat_rate,'color':color_rate,'flip':flip_rate,'degradate':degradate_params}
return {'flag':flag_dict,'rate':rate_dict}
def random_transform_single_image(img,finesize,params=None,test_flag = False):
if params is None:
params = get_transform_params()
if test_flag:
params['flag']['scale'] = False
if params['flag']['scale']:
h,w = img.shape[:2]
loadsize = min((h,w))
a = (float(h)/float(w))*params['rate']['scale']
if h<w:
img = cv2.resize(img, (int(loadsize/a),loadsize))
else:
img = cv2.resize(img, (loadsize,int(loadsize*a)))
if params['flag']['crop']:
h,w = img.shape[:2]
h_move = int((h-finesize)*params['rate']['crop'][0])
......@@ -193,8 +135,8 @@ def random_transform_single_image(img,finesize,params=None,test_flag = False):
if params['flag']['flip']:
img = img[:,::-1,:]
if params['flag']['blur']:
img = impro.dctblur(img,params['rate']['blur'])
if params['flag']['degradate']:
img = degradater.degradate(img,params['rate']['degradate'])
#check shape
if img.shape[0]!= finesize or img.shape[1]!= finesize:
......@@ -250,11 +192,6 @@ def random_transform_pair_image(img,mask,finesize,test_flag = False):
if random.random()<0.5:
img = impro.dctblur(img,random.randint(1,15))
# interpolations = [cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_LANCZOS4]
# size_ran = random.uniform(0.7,1.5)
# img = cv2.resize(img, (int(finesize*size_ran),int(finesize*size_ran)),interpolation=interpolations[random.randint(0,2)])
# img = cv2.resize(img, (finesize,finesize),interpolation=interpolations[random.randint(0,2)])
#check shape
if img.shape[0]!= finesize or img.shape[1]!= finesize or mask.shape[0]!= finesize or mask.shape[1]!= finesize:
img = cv2.resize(img,(finesize,finesize))
......@@ -262,31 +199,6 @@ def random_transform_pair_image(img,mask,finesize,test_flag = False):
print('warning! shape error.')
return img,mask
# def load_train_video(videoname,img_index,opt):
# N = opt.N
# input_img = np.zeros((opt.loadsize,opt.loadsize,3*N+1), dtype='uint8')
# # this frame
# this_mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index)+'.png'),'gray',loadsize=opt.loadsize)
# input_img[:,:,-1] = this_mask
# #print(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'))
# ground_true = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index)+'.jpg'),loadsize=opt.loadsize)
# mosaic_size,mod,rect_rat,feather = mosaic.get_random_parameter(ground_true,this_mask)
# start_pos = mosaic.get_random_startpos(num=N,bisa_p=0.3,bisa_max=mosaic_size,bisa_max_part=3)
# # merge other frame
# for i in range(0,N):
# img = impro.imread(os.path.join(opt.dataset,videoname,'origin_image','%05d'%(img_index+i-int(N/2))+'.jpg'),loadsize=opt.loadsize)
# mask = impro.imread(os.path.join(opt.dataset,videoname,'mask','%05d'%(img_index+i-int(N/2))+'.png'),'gray',loadsize=opt.loadsize)
# img_mosaic = mosaic.addmosaic_base(img, mask, mosaic_size,model = mod,rect_rat=rect_rat,feather=feather,start_point=start_pos[i])
# input_img[:,:,i*3:(i+1)*3] = img_mosaic
# # to tensor
# input_img,ground_true = random_transform_video(input_img,ground_true,opt.finesize,N)
# input_img = im2tensor(input_img,bgr2rgb=False,gpu_id=-1,use_transform = False,is0_1=False)
# ground_true = im2tensor(ground_true,bgr2rgb=False,gpu_id=-1,use_transform = False,is0_1=False)
# return input_img,ground_true
def showresult(img1,img2,img3,name,is0_1 = False):
size = img1.shape[3]
showimg=np.zeros((size,size*3,3))
......
import os
import random
import numpy as np
from multiprocessing import Process, Queue
from . import image_processing as impro
from . import mosaic,data
class VideoLoader(object):
    """Load a single video (pre-converted to frame images) as training streams.

    How to use:
    1. Init VideoLoader as ``loader``.
    2. Read ``loader.ori_stream`` / ``loader.mosaic_stream`` (shape B,C,T,H,W).
    3. Call ``loader.next()`` to advance to the next stream.

    NOTE(review): relies on ``opt.M`` (presumably the per-video frame count
    -- confirm), ``opt.S`` (frame stride), ``opt.T`` (stream length),
    ``opt.N``, ``opt.loadsize`` and ``opt.finesize``, and on frames named
    ``origin_image/%05d.jpg`` with matching ``mask/%05d.png`` starting at
    index 00001 -- verify against the dataset generation script.
    """
    def __init__(self, opt, video_dir, test_flag=False):
        super(VideoLoader, self).__init__()
        self.opt = opt
        self.test_flag = test_flag
        self.video_dir = video_dir
        self.t = 0  # number of times next() has been called
        self.n_iter = self.opt.M -self.opt.S*(self.opt.T+1)
        # One augmentation parameter set is sampled once and shared by every
        # frame, so the whole clip is transformed consistently.
        self.transform_params = data.get_transform_params()
        self.ori_load_pool = []
        self.mosaic_load_pool = []
        self.previous_pred = None
        # Sample mosaic parameters once from the first frame so the mosaic
        # pattern stays fixed over the whole video.
        feg_ori = impro.imread(os.path.join(video_dir,'origin_image','00001.jpg'),loadsize=self.opt.loadsize,rgb=True)
        feg_mask = impro.imread(os.path.join(video_dir,'mask','00001.png'),mod='gray',loadsize=self.opt.loadsize)
        self.mosaic_size,self.mod,self.rect_rat,self.feather = mosaic.get_random_parameter(feg_ori,feg_mask)
        self.startpos = [random.randint(0,self.mosaic_size),random.randint(0,self.mosaic_size)]

        #Init load pool: read the first S*T frames, mosaic + augment + normalize them.
        for i in range(self.opt.S*self.opt.T):
            #print(os.path.join(video_dir,'origin_image','%05d' % (i+1)+'.jpg'))
            _ori_img = impro.imread(os.path.join(video_dir,'origin_image','%05d' % (i+1)+'.jpg'),loadsize=self.opt.loadsize,rgb=True)
            _mask = impro.imread(os.path.join(video_dir,'mask','%05d' % (i+1)+'.png' ),mod='gray',loadsize=self.opt.loadsize)
            _mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size,0, self.mod,self.rect_rat,self.feather,self.startpos)
            _ori_img = data.random_transform_single_image(_ori_img,opt.finesize,self.transform_params)
            _mosaic_img = data.random_transform_single_image(_mosaic_img,opt.finesize,self.transform_params)
            self.ori_load_pool.append(self.normalize(_ori_img))
            self.mosaic_load_pool.append(self.normalize(_mosaic_img))
        self.ori_load_pool = np.array(self.ori_load_pool)
        self.mosaic_load_pool = np.array(self.mosaic_load_pool)

        #Init first stream: pick every S-th pooled frame, T frames total.
        self.ori_stream = self.ori_load_pool [np.linspace(0, (self.opt.T-1)*self.opt.S,self.opt.T,dtype=np.int64)].copy()
        self.mosaic_stream = self.mosaic_load_pool[np.linspace(0, (self.opt.T-1)*self.opt.S,self.opt.T,dtype=np.int64)].copy()
        # stream B,T,H,W,C -> B,C,T,H,W
        self.ori_stream = self.ori_stream.reshape (1,self.opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
        self.mosaic_stream = self.mosaic_stream.reshape(1,self.opt.T,opt.finesize,opt.finesize,3).transpose((0,4,1,2,3))
        #Init first previous frame: ground truth stands in for the "previous
        #prediction" on the very first iteration only.
        self.previous_pred = self.ori_load_pool[self.opt.S*self.opt.N-1].copy()
        # previous B,C,H,W
        self.previous_pred = self.previous_pred.reshape(1,opt.finesize,opt.finesize,3).transpose((0,3,1,2))

    def normalize(self,data):
        '''
        Normalize uint8 image data to -1 ~ 1 (float32).
        '''
        return (data.astype(np.float32)/255.0-0.5)/0.5

    def anti_normalize(self,data):
        # Inverse of normalize(): map -1 ~ 1 floats back to uint8 0 ~ 255.
        return np.clip((data*0.5+0.5)*255,0,255).astype(np.uint8)

    def next(self):
        # Advance the sliding window by one frame: shift the load pools left,
        # read/mosaic/augment one new frame into the last slot, and rebuild
        # the strided streams.  The very first call (t == 0) changes nothing
        # because __init__ already built the current streams.
        if self.t != 0:
            # After the first step there is no ground-truth "previous
            # prediction"; the trainer is expected to supply its own output.
            self.previous_pred = None
            self.ori_load_pool [:self.opt.S*self.opt.T-1] = self.ori_load_pool [1:self.opt.S*self.opt.T]
            self.mosaic_load_pool[:self.opt.S*self.opt.T-1] = self.mosaic_load_pool[1:self.opt.S*self.opt.T]
            #print(os.path.join(self.video_dir,'origin_image','%05d' % (self.opt.S*self.opt.T+self.t)+'.jpg'))
            _ori_img = impro.imread(os.path.join(self.video_dir,'origin_image','%05d' % (self.opt.S*self.opt.T+self.t)+'.jpg'),loadsize=self.opt.loadsize,rgb=True)
            _mask = impro.imread(os.path.join(self.video_dir,'mask','%05d' % (self.opt.S*self.opt.T+self.t)+'.png' ),mod='gray',loadsize=self.opt.loadsize)
            _mosaic_img = mosaic.addmosaic_base(_ori_img, _mask, self.mosaic_size,0, self.mod,self.rect_rat,self.feather,self.startpos)
            _ori_img = data.random_transform_single_image(_ori_img,self.opt.finesize,self.transform_params)
            _mosaic_img = data.random_transform_single_image(_mosaic_img,self.opt.finesize,self.transform_params)
            _ori_img,_mosaic_img = self.normalize(_ori_img),self.normalize(_mosaic_img)
            self.ori_load_pool [self.opt.S*self.opt.T-1] = _ori_img
            self.mosaic_load_pool[self.opt.S*self.opt.T-1] = _mosaic_img

            self.ori_stream = self.ori_load_pool [np.linspace(0, (self.opt.T-1)*self.opt.S,self.opt.T,dtype=np.int64)].copy()
            self.mosaic_stream = self.mosaic_load_pool[np.linspace(0, (self.opt.T-1)*self.opt.S,self.opt.T,dtype=np.int64)].copy()
            # stream B,T,H,W,C -> B,C,T,H,W
            self.ori_stream = self.ori_stream.reshape (1,self.opt.T,self.opt.finesize,self.opt.finesize,3).transpose((0,4,1,2,3))
            self.mosaic_stream = self.mosaic_stream.reshape(1,self.opt.T,self.opt.finesize,self.opt.finesize,3).transpose((0,4,1,2,3))
        self.t += 1
class VideoDataLoader(object):
    """Multi-process batch loader built on top of VideoLoader.

    Spawns ``opt.load_thread`` daemon processes, each running load() over a
    private slice of the shuffled, epoch-replicated video list; batches are
    handed to the consumer through a bounded multiprocessing Queue via
    get_data().
    """
    def __init__(self, opt, videolist, test_flag=False):
        super(VideoDataLoader, self).__init__()
        self.videolist = []
        self.opt = opt
        self.test_flag = test_flag
        # Replicate the list once per epoch, then shuffle globally so epochs
        # are interleaved across worker processes.
        for i in range(self.opt.n_epoch):
            self.videolist += videolist
        random.shuffle(self.videolist)
        self.each_video_n_iter = self.opt.M -self.opt.S*(self.opt.T+1)
        # Total iterations; floor divisions drop videos that don't fill a
        # whole worker/batch, so this matches what the workers actually emit.
        self.n_iter = len(self.videolist)//self.opt.load_thread//self.opt.batchsize*self.each_video_n_iter*self.opt.load_thread
        # Bounded queue: workers block once load_thread batches are pending.
        self.queue = Queue(self.opt.load_thread)
        # Pre-allocated batch buffers.
        self.ori_stream = np.zeros((self.opt.batchsize,3,self.opt.T,self.opt.finesize,self.opt.finesize),dtype=np.float32)# B,C,T,H,W
        self.mosaic_stream = np.zeros((self.opt.batchsize,3,self.opt.T,self.opt.finesize,self.opt.finesize),dtype=np.float32)# B,C,T,H,W
        self.previous_pred = np.zeros((self.opt.batchsize,3,self.opt.finesize,self.opt.finesize),dtype=np.float32)
        self.load_init()

    def load(self,videolist):
        # Worker-process loop: take batchsize videos at a time, step every
        # VideoLoader in lock-step and push one batch per step to the queue.
        # NOTE(review): presumably runs in a forked child so the self.*
        # buffers are process-local copies -- confirm on non-fork platforms.
        for load_video_iter in range(len(videolist)//self.opt.batchsize):
            iter_videolist = videolist[load_video_iter*self.opt.batchsize:(load_video_iter+1)*self.opt.batchsize]
            videoloaders = [VideoLoader(self.opt,os.path.join(self.opt.dataset,iter_videolist[i]),self.test_flag) for i in range(self.opt.batchsize)]
            for each_video_iter in range(self.each_video_n_iter):
                for i in range(self.opt.batchsize):
                    self.ori_stream[i] = videoloaders[i].ori_stream
                    self.mosaic_stream[i] = videoloaders[i].mosaic_stream
                    if each_video_iter == 0:
                        # Only the first step has a ground-truth "previous
                        # prediction"; later steps send None instead.
                        self.previous_pred[i] = videoloaders[i].previous_pred
                    videoloaders[i].next()
                if each_video_iter == 0:
                    self.queue.put([self.ori_stream.copy(),self.mosaic_stream.copy(),self.previous_pred])
                else:
                    self.queue.put([self.ori_stream.copy(),self.mosaic_stream.copy(),None])

    def load_init(self):
        # Split the video list evenly across load_thread daemon workers;
        # daemons die automatically with the parent process.
        ptvn = len(self.videolist)//self.opt.load_thread #per_thread_video_num
        for i in range(self.opt.load_thread):
            p = Process(target=self.load,args=(self.videolist[i*ptvn:(i+1)*ptvn],))
            p.daemon = True
            p.start()

    def get_data(self):
        # Blocks until a worker has produced the next batch.
        return self.queue.get()
\ No newline at end of file
'''
https://github.com/sonack/GFRNet_pytorch_new
'''
import random
import cv2
import numpy as np
def gaussian_blur(img, sigma=3, size=13):
    """Blur ``img`` with a Gaussian kernel; no-op when sigma <= 0.

    ``size`` may be an int (square kernel) or an explicit (w, h) pair.
    """
    if sigma <= 0:
        return img
    kernel = (size, size) if isinstance(size, int) else size
    return cv2.GaussianBlur(img, kernel, sigma)
def down(img, scale, shape):
    """Downsample ``img`` by ``scale`` with cubic interpolation; no-op for scale <= 1.

    ``shape`` supplies the reference (h, w, channels) of the original image.
    """
    if scale <= 1:
        return img
    h, w, _ = shape
    target = (int(w / scale), int(h / scale))
    return cv2.resize(img, target, interpolation=cv2.INTER_CUBIC)
def up(img, scale, shape):
    """Resize ``img`` back to the (h, w) stored in ``shape``; no-op for scale <= 1."""
    if scale <= 1:
        return img
    h, w, _ = shape
    return cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
def awgn(img, level):
    """Add white Gaussian noise with standard deviation ``level``; no-op for level <= 0.

    The noisy result is clipped to [0, 255] and cast back to uint8.
    """
    if level <= 0:
        return img
    noisy = img + np.random.randn(*img.shape) * level
    return noisy.clip(0, 255).astype(np.uint8)
def jpeg_compressor(img, quality):
    """Round-trip ``img`` through an in-memory JPEG at ``quality``.

    quality == 0 indicates no lossy compression (i.e. losslessly
    compressed upstream), so the image is returned untouched.
    """
    if quality <= 0:
        return img
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    encoded = cv2.imencode('.jpg', img, encode_param)[1]
    return cv2.imdecode(encoded, 1)
def get_random_degenerate_params(mod='strong'):
    """Sample random degradation parameters for degradate().

    mod : 'strong' | 'only_downsample' | 'only_4x' | 'weaker_1' | 'weaker_2'
        Any other value (notably 'original', which degradate() passes when
        given falsy params) yields identity parameters, i.e. no degradation.
        Previously an unknown mod fell through with the candidate lists
        undefined and raised NameError.

    Returns a dict with keys blur_sigma, blur_size, updown_scale,
    awgn_level, jpeg_quality.  A value of 0 (or scale 1) disables the
    corresponding stage in degradate().
    """
    params = {}
    gaussianBlur_size_list = list(range(3, 14, 2))

    if mod == 'strong':
        gaussianBlur_sigma_list = [1 + x for x in range(3)]
        gaussianBlur_sigma_list += [0]
        downsample_scale_list = [1 + x * 0.1 for x in range(0, 71)]
        awgn_level_list = list(range(1, 8, 1))
        jpeg_quality_list = list(range(10, 41, 1))
        # Pad with zeros so roughly a quarter of draws skip jpeg compression.
        jpeg_quality_list += int(len(jpeg_quality_list) * 0.33) * [0]
    elif mod == 'only_downsample':
        gaussianBlur_sigma_list = [0]
        downsample_scale_list = [1 + x * 0.1 for x in range(0, 71)]
        awgn_level_list = [0]
        jpeg_quality_list = [0]
    elif mod == 'only_4x':
        gaussianBlur_sigma_list = [0]
        downsample_scale_list = [4]
        awgn_level_list = [0]
        jpeg_quality_list = [0]
    elif mod == 'weaker_1':  # each degradation triggers with ~0.5 probability
        gaussianBlur_sigma_list = [1 + x for x in range(3)]
        gaussianBlur_sigma_list += int(len(gaussianBlur_sigma_list)) * [0]
        downsample_scale_list = [1 + x * 0.1 for x in range(0, 71)]
        downsample_scale_list += int(len(downsample_scale_list)) * [1]
        awgn_level_list = list(range(1, 8, 1))
        awgn_level_list += int(len(awgn_level_list)) * [0]
        jpeg_quality_list = list(range(10, 41, 1))
        jpeg_quality_list += int(len(jpeg_quality_list)) * [0]
    elif mod == 'weaker_2':  # weaker than weaker_1: jpeg quality in [20, 40]
        gaussianBlur_sigma_list = [1 + x for x in range(3)]
        gaussianBlur_sigma_list += int(len(gaussianBlur_sigma_list)) * [0]
        downsample_scale_list = [1 + x * 0.1 for x in range(0, 71)]
        downsample_scale_list += int(len(downsample_scale_list)) * [1]
        awgn_level_list = list(range(1, 8, 1))
        awgn_level_list += int(len(awgn_level_list)) * [0]
        jpeg_quality_list = list(range(20, 41, 1))
        jpeg_quality_list += int(len(jpeg_quality_list)) * [0]
    else:
        # Unknown mod: identity parameters -> degradate() becomes a no-op.
        gaussianBlur_sigma_list = [0]
        downsample_scale_list = [1]
        awgn_level_list = [0]
        jpeg_quality_list = [0]

    params['blur_sigma'] = random.choice(gaussianBlur_sigma_list)
    params['blur_size'] = random.choice(gaussianBlur_size_list)
    params['updown_scale'] = random.choice(downsample_scale_list)
    params['awgn_level'] = random.choice(awgn_level_list)
    params['jpeg_quality'] = random.choice(jpeg_quality_list)
    return params
def degradate(img, params, jpeg_last=False):
    """Apply the blur -> downsample -> noise -> jpeg/upsample degradation chain.

    ``jpeg_last`` controls whether jpeg compression happens after (True) or
    before (False) upsampling back to the original size.  Falsy ``params``
    are re-sampled via get_random_degenerate_params('original').
    """
    shape = img.shape
    if not params:
        params = get_random_degenerate_params('original')
    # The first three stages are common to both orderings.
    img = gaussian_blur(img, params['blur_sigma'], params['blur_size'])
    img = down(img, params['updown_scale'], shape)
    img = awgn(img, params['awgn_level'])
    if jpeg_last:
        img = up(img, params['updown_scale'], shape)
        img = jpeg_compressor(img, params['jpeg_quality'])
    else:
        img = jpeg_compressor(img, params['jpeg_quality'])
        img = up(img, params['updown_scale'], shape)
    return img
\ No newline at end of file
......@@ -8,15 +8,6 @@ system_type = 'Linux'
if 'Windows' in platform.platform():
system_type = 'Windows'
DCT_Q = np.array([[8,16,19,22,26,27,29,34],
[16,16,22,24,27,29,34,37],
[19,22,26,27,29,34,34,38],
[22,22,26,27,29,34,37,40],
[22,26,27,29,32,35,40,48],
[26,27,29,32,35,40,48,58],
[26,27,29,34,38,46,56,59],
[27,29,35,38,46,56,69,83]])
def imread(file_path,mod = 'normal',loadsize = 0, rgb=False):
'''
mod: 'normal' | 'gray' | 'all'
......@@ -121,34 +112,6 @@ def makedataset(target_image,orgin_image):
img[0:256,0:256] = target_image[0:256,int(w/2-256/2):int(w/2+256/2)]
img[0:256,256:512] = orgin_image[0:256,int(w/2-256/2):int(w/2+256/2)]
return img
def block_dct_and_idct(g, QQF, QQF_16):
    """Quantize one 8x8 block in DCT space (tables QQF / QQF_16) and transform back."""
    coeffs = cv2.dct(g)
    quantized = np.round(16.0 * coeffs / QQF) * QQF_16
    return cv2.idct(quantized)
def image_dct_and_idct(I, QF):
    """Blockwise (8x8) DCT quantization blur of image ``I`` with strength ``QF``.

    Mutates ``I`` in place block by block and also returns it.  Assumes the
    dimensions of ``I`` are effectively multiples of 8 (trailing partial
    blocks are skipped).
    """
    h, w = I.shape
    QQF = DCT_Q * QF
    QQF_16 = QQF / 16.0
    for row in range(0, 8 * (h // 8), 8):
        for col in range(0, 8 * (w // 8), 8):
            block = I[row:row + 8, col:col + 8]
            I[row:row + 8, col:col + 8] = cv2.idct(np.round(16.0 * cv2.dct(block) / QQF) * QQF_16)
    return I
def dctblur(img, Q):
    """Blur ``img`` via blockwise DCT quantization.

    Q: 1~20, 1 -> best quality.  The image is first cropped to a multiple
    of 8 in each dimension; output is uint8 clipped to [0, 255].
    """
    h, w = img.shape[:2]
    img = img[:8 * (h // 8), :8 * (w // 8)].astype(np.float32)
    if img.ndim == 2:
        img = image_dct_and_idct(img, Q)
    elif img.ndim == 3:
        for channel in range(img.shape[2]):
            img[:, :, channel] = image_dct_and_idct(img[:, :, channel], Q)
    return np.clip(img, 0, 255).astype(np.uint8)
def find_mostlikely_ROI(mask):
contours,hierarchy=cv2.findContours(mask, cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)
......@@ -215,19 +178,6 @@ def mask_area(mask):
area = 0
return area
def Q_lapulase(resImg):
    '''
    Evaluate image sharpness/quality via the variance of the Laplacian.
    score > 20 normal
    score > 50 clear
    '''
    # Work on a grayscale copy, resized to a fixed scale so scores are
    # comparable across differently sized inputs.
    img2gray = cv2.cvtColor(resImg, cv2.COLOR_BGR2GRAY)
    img2gray = resize(img2gray,512)  # resize() is a module-local helper (not shown here)
    res = cv2.Laplacian(img2gray, cv2.CV_64F)
    score = res.var()
    return score
def replace_mosaic(img_origin,img_fake,mask,x,y,size,no_feather):
img_fake = cv2.resize(img_fake,(size*2,size*2),interpolation=cv2.INTER_LANCZOS4)
if no_feather:
......@@ -257,6 +207,18 @@ def replace_mosaic(img_origin,img_fake,mask,x,y,size,no_feather):
return img_result
def Q_lapulase(resImg):
    """Estimate image clarity as the variance of the Laplacian.

    Rough scale: score > 20 is normal, score > 50 is clear.
    """
    gray = cv2.cvtColor(resImg, cv2.COLOR_BGR2GRAY)
    gray = resize(gray, 512)
    return cv2.Laplacian(gray, cv2.CV_64F).var()
def psnr(img1,img2):
mse = np.mean((img1/255.0-img2/255.0)**2)
if mse < 1e-10:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册