提交 fa59f69e 编写于 作者: L LielinJiang

add deepremaster, fix some bug

上级 6a5109c5
import os
import sys
cur_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(cur_path)
import paddle
import paddle.nn as nn
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
import subprocess
import utils
from remasternet import NetworkR, NetworkC
from paddle.incubate.hapi.download import get_path_from_url
DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
parser = argparse.ArgumentParser(description='Remastering')
parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory')
parser.add_argument('--colorization', action='store_true', default=False, help='Remaster without colorization')
parser.add_argument('--mindim', type=int, default='360', help='Length of minimum image edges')
class DeepReasterPredictor:
def __init__(self, input, output, weight_path=None, colorization=False, reference_dir=None, mindim=360):
self.input = input
self.output = os.path.join(output, 'DeepRemaster')
self.colorization = colorization
self.reference_dir = reference_dir
self.mindim = mindim
if weight_path is None:
weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path)
state_dict, _ = paddle.load(weight_path)
self.modelR = NetworkR()
self.modelR.load_dict(state_dict['modelR'])
self.modelR.eval()
if colorization:
self.modelC = NetworkC()
self.modelC.load_dict(state_dict['modelC'])
self.modelC.eval()
def run(self):
outputdir = self.output
outputdir_in = os.path.join(outputdir, 'input/')
os.makedirs( outputdir_in, exist_ok=True )
outputdir_out = os.path.join(outputdir, 'output/')
os.makedirs( outputdir_out, exist_ok=True )
# Prepare reference images
if self.colorization:
if self.reference_dir is not None:
import glob
ext_list = ['png','jpg','bmp']
reference_files = []
for ext in ext_list:
reference_files += glob.glob( self.reference_dir+'/*.'+ext, recursive=True )
aspect_mean = 0
minedge_dim = 256
refs = []
for v in reference_files:
refimg = Image.open( v ).convert('RGB')
w, h = refimg.size
aspect_mean += w/h
refs.append( refimg )
aspect_mean /= len(reference_files)
target_w = int(256*aspect_mean) if aspect_mean>1 else 256
target_h = 256 if aspect_mean>=1 else int(256/aspect_mean)
refimgs = []
for i, v in enumerate(refs):
refimg = utils.addMergin( v, target_w=target_w, target_h=target_h )
refimg = np.array(refimg).astype('float32').transpose(2, 0, 1) / 255.0
refimgs.append(refimg)
refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
refimgs = paddle.unsqueeze(refimgs, 0)
# Load video
cap = cv2.VideoCapture( self.input )
nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
minwh = min(v_w,v_h)
scale = 1
if minwh != self.mindim:
scale = self.mindim / minwh
t_w = round(v_w*scale/16.)*16
t_h = round(v_h*scale/16.)*16
fps = cap.get(cv2.CAP_PROP_FPS)
pbar = tqdm(total=nframes)
block = 5
# Process
with paddle.no_grad():
it = 0
while True:
frame_pos = it*block
if frame_pos >= nframes:
break
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
if block >= nframes-frame_pos:
proc_g = nframes-frame_pos
else:
proc_g = block
input = None
gtC = None
for i in range(proc_g):
index = frame_pos + i
_, frame = cap.read()
frame = cv2.resize(frame, (t_w, t_h))
nchannels = frame.shape[2]
if nchannels == 1 or self.colorization:
frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
cv2.imwrite(outputdir_in+'%07d.png'%index, frame_l)
frame_l = paddle.to_tensor(frame_l.astype('float32'))
frame_l = paddle.reshape(frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
frame_l = paddle.transpose(frame_l, [2, 0, 1])
frame_l /= 255.
frame_l = paddle.reshape(frame_l, [1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
elif nchannels == 3:
cv2.imwrite(outputdir_in+'%07d.png'%index, frame)
frame = frame[:,:,::-1] ## BGR -> RGB
frame_l, frame_ab = utils.convertRGB2LABTensor( frame )
frame_l = frame_l.transpose([2, 0, 1])
frame_ab = frame_ab.transpose([2, 0, 1])
frame_l = frame_l.reshape([1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
frame_ab = frame_ab.reshape([1, frame_ab.shape[0], 1, frame_ab.shape[1], frame_ab.shape[2]])
if input is not None:
paddle.concat( (input, frame_l), 2 )
input = frame_l if i==0 else paddle.concat( (input, frame_l), 2 )
if nchannels==3 and not self.colorization:
gtC = frame_ab if i==0 else paddle.concat( (gtC, frame_ab), 2 )
input = paddle.to_tensor(input)
output_l = self.modelR( input ) # [B, C, T, H, W]
# Save restoration output without colorization when using the option [--disable_colorization]
if not self.colorization:
for i in range( proc_g ):
index = frame_pos + i
if nchannels==3:
out_l = output_l.detach()[0,:,i]
out_ab = gtC[0,:,i]
out = paddle.concat((out_l, out_ab),axis=0).detach().numpy().transpose((1, 2, 0))
out = Image.fromarray( np.uint8( utils.convertLAB2RGB( out )*255 ) )
out.save( outputdir_out+'%07d.png'%(index) )
else:
raise ValueError('channels of imag3 must be 3!')
# Perform colorization
else:
if self.reference_dir is None:
output_ab = self.modelC( output_l )
else:
output_ab = self.modelC( output_l, refimgs )
output_l = output_l.detach()
output_ab = output_ab.detach()
for i in range( proc_g ):
index = frame_pos + i
out_l = output_l[0,:,i,:,:]
out_c = output_ab[0,:,i,:,:]
output = paddle.concat((out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
output = Image.fromarray( np.uint8( utils.convertLAB2RGB( output )*255 ) )
output.save( outputdir_out+'%07d.png'%index )
it = it + 1
pbar.update(proc_g)
# Save result videos
outfile = os.path.join(outputdir, self.input.split('/')[-1].split('.')[0])
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (fps, outputdir_in, fps, outfile )
subprocess.call( cmd, shell=True )
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (fps, outputdir_out, fps, outfile )
subprocess.call( cmd, shell=True )
cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % ( outfile, outfile, outfile )
subprocess.call( cmd, shell=True )
cap.release()
pbar.close()
return outputdir_out, '%s_out.mp4' % outfile
if __name__ == "__main__":
args = parser.parse_args()
paddle.disable_static()
predictor = DeepReasterPredictor(args.input, args.output, colorization=args.colorization,
reference_dir=args.reference_dir, mindim=args.mindim)
predictor.run()
\ No newline at end of file
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class TempConv(nn.Layer):
def __init__(self, in_planes, out_planes, kernel_size=(1,3,3), stride=(1,1,1), padding=(0,1,1) ):
super(TempConv, self).__init__()
self.conv3d = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding)
self.bn = nn.BatchNorm( out_planes )
def forward(self, x):
return F.elu( self.bn(self.conv3d(x)))
class Upsample(nn.Layer):
def __init__(self, in_planes, out_planes, scale_factor=(1,2,2)):
super(Upsample, self).__init__()
self.scale_factor = scale_factor
self.conv3d = nn.Conv3d( in_planes, out_planes, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) )
self.bn = nn.BatchNorm( out_planes )
def forward(self, x):
out_size = x.shape[2:]
for i in range(3):
out_size[i] = self.scale_factor[i] * out_size[i]
return F.elu( self.bn( self.conv3d( F.interpolate(x, size=out_size, mode='trilinear', align_corners=False, data_format='NCDHW', align_mode=0))))
class UpsampleConcat(nn.Layer):
def __init__(self, in_planes_up, in_planes_flat, out_planes):
super(UpsampleConcat, self).__init__()
self.conv3d = TempConv( in_planes_up + in_planes_flat, out_planes, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) )
def forward(self, x1, x2):
scale_factor=(1,2,2)
out_size = x1.shape[2:]
for i in range(3):
out_size[i] = scale_factor[i] * out_size[i]
x1 = F.interpolate(x1, size=out_size, mode='trilinear', align_corners=False, data_format='NCDHW', align_mode=0)
x = paddle.concat([x1, x2], axis=1)
return self.conv3d(x)
class SourceReferenceAttention(paddle.fluid.dygraph.Layer):
"""
Source-Reference Attention Layer
"""
def __init__(self, in_planes_s, in_planes_r):
"""
Parameters
----------
in_planes_s: int
Number of input source feature vector channels.
in_planes_r: int
Number of input reference feature vector channels.
"""
super(SourceReferenceAttention,self).__init__()
self.query_conv = nn.Conv3d(in_channels=in_planes_s,
out_channels=in_planes_s//8, kernel_size=1 )
self.key_conv = nn.Conv3d(in_channels=in_planes_r,
out_channels=in_planes_r//8, kernel_size=1 )
self.value_conv = nn.Conv3d(in_channels=in_planes_r,
out_channels=in_planes_r, kernel_size=1 )
self.gamma = self.create_parameter(shape=[1], dtype=self.query_conv.weight.dtype,
default_initializer=paddle.fluid.initializer.Constant(0.0))
def forward(self, source, reference):
s_batchsize, sC, sT, sH, sW = source.shape
r_batchsize, rC, rT, rH, rW = reference.shape
proj_query = paddle.reshape(self.query_conv(source), [s_batchsize,-1,sT*sH*sW])
proj_query = paddle.transpose(proj_query, [0, 2, 1])
proj_key = paddle.reshape(self.key_conv(reference), [r_batchsize,-1,rT*rW*rH])
energy = paddle.bmm( proj_query, proj_key )
attention = F.softmax(energy)
proj_value = paddle.reshape(self.value_conv(reference), [r_batchsize,-1,rT*rH*rW])
out = paddle.bmm(proj_value,paddle.transpose(attention, [0,2,1]))
out = paddle.reshape(out, [s_batchsize, sC, sT, sH, sW])
out = self.gamma*out + source
return out, attention
class NetworkR( nn.Layer ):
def __init__(self):
super(NetworkR, self).__init__()
self.layers = nn.Sequential(
nn.ReplicationPad3d((1,1,1,1,1,1)),
TempConv( 1, 64, kernel_size=(3,3,3), stride=(1,2,2), padding=(0,0,0) ),
TempConv( 64, 128, kernel_size=(3,3,3), padding=(1,1,1) ),
TempConv( 128, 128, kernel_size=(3,3,3), padding=(1,1,1) ),
TempConv( 128, 256, kernel_size=(3,3,3), stride=(1,2,2), padding=(1,1,1) ),
TempConv( 256, 256, kernel_size=(3,3,3), padding=(1,1,1) ),
TempConv( 256, 256, kernel_size=(3,3,3), padding=(1,1,1) ),
TempConv( 256, 256, kernel_size=(3,3,3), padding=(1,1,1) ),
TempConv( 256, 256, kernel_size=(3,3,3), padding=(1,1,1) ),
Upsample( 256, 128 ),
TempConv( 128, 64, kernel_size=(3,3,3), padding=(1,1,1) ),
TempConv( 64, 64, kernel_size=(3,3,3), padding=(1,1,1) ),
Upsample( 64, 16 ),
nn.Conv3d( 16, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) )
)
def forward(self, x):
return paddle.clip((x + paddle.fluid.layers.tanh( self.layers( ((x * 1).detach())-0.4462414 ) )), 0.0, 1.0)
class NetworkC( nn.Layer ):
def __init__(self):
super(NetworkC, self).__init__()
self.down1 = nn.Sequential(
nn.ReplicationPad3d((1,1,1,1,0,0)),
TempConv( 1, 64, stride=(1,2,2), padding=(0,0,0) ),
TempConv( 64, 128 ),
TempConv( 128, 128 ),
TempConv( 128, 256, stride=(1,2,2) ),
TempConv( 256, 256 ),
TempConv( 256, 256 ),
TempConv( 256, 512, stride=(1,2,2) ),
TempConv( 512, 512 ),
TempConv( 512, 512 )
)
self.flat = nn.Sequential(
TempConv( 512, 512 ),
TempConv( 512, 512 )
)
self.down2 = nn.Sequential(
TempConv( 512, 512, stride=(1,2,2) ),
TempConv( 512, 512 ),
)
self.stattn1 = SourceReferenceAttention( 512, 512 ) # Source-Reference Attention
self.stattn2 = SourceReferenceAttention( 512, 512 ) # Source-Reference Attention
self.selfattn1 = SourceReferenceAttention( 512, 512 ) # Self Attention
self.conv1 = TempConv( 512, 512 )
self.up1 = UpsampleConcat( 512, 512, 512 ) # 1/8
self.selfattn2 = SourceReferenceAttention( 512, 512 ) # Self Attention
self.conv2 = TempConv( 512, 256, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) )
self.up2 = nn.Sequential(
Upsample( 256, 128 ), # 1/4
TempConv( 128, 64, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) )
)
self.up3 = nn.Sequential(
Upsample( 64, 32 ), # 1/2
TempConv( 32, 16, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) )
)
self.up4 = nn.Sequential(
Upsample( 16, 8 ), # 1/1
nn.Conv3d( 8, 2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) )
)
self.reffeatnet1 = nn.Sequential(
TempConv( 3, 64, stride=(1,2,2) ),
TempConv( 64, 128 ),
TempConv( 128, 128 ),
TempConv( 128, 256, stride=(1,2,2) ),
TempConv( 256, 256 ),
TempConv( 256, 256 ),
TempConv( 256, 512, stride=(1,2,2) ),
TempConv( 512, 512 ),
TempConv( 512, 512 ),
)
self.reffeatnet2 = nn.Sequential(
TempConv( 512, 512, stride=(1,2,2) ),
TempConv( 512, 512 ),
TempConv( 512, 512 ),
)
def forward(self, x, x_refs=None):
x1 = self.down1( x - 0.4462414 )
if x_refs is not None:
x_refs = paddle.transpose(x_refs, [0, 2, 1, 3, 4]) # [B,T,C,H,W] --> [B,C,T,H,W]
reffeat = self.reffeatnet1( x_refs-0.48 )
x1, _ = self.stattn1( x1, reffeat )
x2 = self.flat( x1 )
out = self.down2( x1 )
if x_refs is not None:
reffeat2 = self.reffeatnet2( reffeat )
out, _ = self.stattn2( out, reffeat2 )
out = self.conv1( out )
out, _ = self.selfattn1( out, out )
out = self.up1( out, x2 )
out, _ = self.selfattn2( out, out )
out = self.conv2( out )
out = self.up2( out )
out = self.up3( out )
out = self.up4( out )
return F.sigmoid( out )
\ No newline at end of file
import paddle
from skimage import color
import numpy as np
from PIL import Image
def convertLAB2RGB( lab ):
lab[:, :, 0:1] = lab[:, :, 0:1] * 100 # [0, 1] -> [0, 100]
lab[:, :, 1:3] = np.clip(lab[:, :, 1:3] * 255 - 128, -100, 100) # [0, 1] -> [-128, 128]
rgb = color.lab2rgb( lab.astype(np.float64) )
return rgb
def convertRGB2LABTensor( rgb ):
lab = color.rgb2lab( np.asarray( rgb ) ) # RGB -> LAB L[0, 100] a[-127, 128] b[-128, 127]
ab = np.clip(lab[:, :, 1:3] + 128, 0, 255) # AB --> [0, 255]
ab = paddle.to_tensor(ab.astype('float32')) / 255.
L = lab[:, :, 0] * 2.55 # L --> [0, 255]
L = Image.fromarray( np.uint8( L ) )
L = paddle.to_tensor(np.array(L).astype('float32')[..., np.newaxis] / 255.0)
return L, ab
def addMergin(img, target_w, target_h, background_color=(0,0,0)):
width, height = img.size
if width==target_w and height==target_h:
return img
scale = max(target_w,target_h)/max(width, height)
width = int(width*scale/16.)*16
height = int(height*scale/16.)*16
img = img.resize((width, height), Image.BICUBIC)
xp = (target_w-width)//2
yp = (target_h-height)//2
result = Image.new(img.mode, (target_w, target_h), background_color)
result.paste(img, (xp, yp))
return result
......@@ -10,4 +10,4 @@ cd -
# proccess_order 使用模型的顺序
python tools/main.py \
--input input.mp4 --output output --proccess_order DAIN DeOldify EDVR
--input input.mp4 --output output --proccess_order DAIN DeepRemaster DeOldify EDVR
......@@ -5,23 +5,30 @@ import argparse
import paddle
from DAIN.predict import VideoFrameInterp
from DeepRemaster.predict import DeepReasterPredictor
from DeOldify.predict import DeOldifyPredictor
from EDVR.predict import EDVRPredictor
parser = argparse.ArgumentParser(description='Fix video')
parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--DAIN_weight', type=str, default=None, help='Path to the reference image directory')
parser.add_argument('--DeOldify_weight', type=str, default=None, help='Path to the reference image directory')
parser.add_argument('--EDVR_weight', type=str, default=None, help='Path to the reference image directory')
parser.add_argument('--DAIN_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--DeepRemaster_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--DeOldify_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--EDVR_weight', type=str, default=None, help='Path to model weight')
# DAIN args
parser.add_argument('--time_step', type=float, default=0.5, help='choose the time steps')
# DeepRemaster args
parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory')
parser.add_argument('--colorization', action='store_true', default=False, help='Remaster with colorization')
parser.add_argument('--mindim', type=int, default=360, help='Length of minimum image edges')
#process order support model name:[DAIN, DeepRemaster, DeOldify, EDVR]
parser.add_argument('--proccess_order', type=str, default='none', nargs='+', help='Process order')
if __name__ == "__main__":
args = parser.parse_args()
print('args...', args)
orders = args.proccess_order
temp_video_path = None
......@@ -32,19 +39,21 @@ if __name__ == "__main__":
predictor = VideoFrameInterp(args.time_step, args.DAIN_weight,
temp_video_path, output_path=args.output)
frames_path, temp_video_path = predictor.run()
elif order == 'DeOldify':
print('frames:', frames_path)
print('video_path:', temp_video_path)
elif order == 'DeepRemaster':
paddle.disable_static()
predictor = DeepReasterPredictor(temp_video_path, args.output, weight_path=args.DeepRemaster_weight,
colorization=args.colorization, reference_dir=args.reference_dir, mindim=args.mindim)
frames_path, temp_video_path = predictor.run()
paddle.enable_static()
elif order == 'DeOldify':
paddle.disable_static()
predictor = DeOldifyPredictor(temp_video_path, args.output, weight_path=args.DeOldify_weight)
frames_path, temp_video_path = predictor.run()
print('frames:', frames_path)
print('video_path:', temp_video_path)
paddle.enable_static()
elif order == 'EDVR':
predictor = EDVRPredictor(temp_video_path, args.output, weight_path=args.EDVR_weight)
frames_path, temp_video_path = predictor.run()
print('frames:', frames_path)
print('video_path:', temp_video_path)
print('Model {} output frames path:'.format(order), frames_path)
print('Model {} output video path:'.format(order), temp_video_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部