未验证 提交 b306aa73 编写于 作者: L LielinJiang 提交者: GitHub

Merge pull request #16 from LielinJiang/adapt-to-2.0-api

Adapt to api 2.0 again
...@@ -11,7 +11,7 @@ from imageio import imread, imsave ...@@ -11,7 +11,7 @@ from imageio import imread, imsave
import cv2 import cv2
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.incubate.hapi.download import get_path_from_url from paddle.utils.download import get_path_from_url
import networks import networks
from util import * from util import *
...@@ -19,6 +19,7 @@ from my_args import parser ...@@ -19,6 +19,7 @@ from my_args import parser
DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar' DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'
def infer_engine(model_dir, def infer_engine(model_dir,
run_mode='fluid', run_mode='fluid',
batch_size=1, batch_size=1,
...@@ -91,7 +92,6 @@ class VideoFrameInterp(object): ...@@ -91,7 +92,6 @@ class VideoFrameInterp(object):
self.exe, self.program, self.fetch_targets = executor(model_path, self.exe, self.program, self.fetch_targets = executor(model_path,
use_gpu=use_gpu) use_gpu=use_gpu)
def run(self): def run(self):
frame_path_input = os.path.join(self.output_path, 'frames-input') frame_path_input = os.path.join(self.output_path, 'frames-input')
frame_path_interpolated = os.path.join(self.output_path, frame_path_interpolated = os.path.join(self.output_path,
...@@ -272,7 +272,7 @@ class VideoFrameInterp(object): ...@@ -272,7 +272,7 @@ class VideoFrameInterp(object):
os.remove(video_pattern_output) os.remove(video_pattern_output)
frames_to_video_ffmpeg(frame_pattern_combined, video_pattern_output, frames_to_video_ffmpeg(frame_pattern_combined, video_pattern_output,
r2) r2)
return frame_pattern_combined, video_pattern_output return frame_pattern_combined, video_pattern_output
......
...@@ -15,15 +15,19 @@ from PIL import Image ...@@ -15,15 +15,19 @@ from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from paddle import fluid from paddle import fluid
from model import build_model from model import build_model
from paddle.incubate.hapi.download import get_path_from_url from paddle.utils.download import get_path_from_url
parser = argparse.ArgumentParser(description='DeOldify') parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video') parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir') parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--weight_path', type=str, default='none', help='Path to the reference image directory') parser.add_argument('--weight_path',
type=str,
default='none',
help='Path to the reference image directory')
DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r): def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
cmd = ffmpeg + [ cmd = ffmpeg + [
...@@ -56,9 +60,9 @@ class DeOldifyPredictor(): ...@@ -56,9 +60,9 @@ class DeOldifyPredictor():
def norm(self, img, render_factor=32, render_base=16): def norm(self, img, render_factor=32, render_base=16):
target_size = render_factor * render_base target_size = render_factor * render_base
img = img.resize((target_size, target_size), resample=Image.BILINEAR) img = img.resize((target_size, target_size), resample=Image.BILINEAR)
img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0 img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
...@@ -69,13 +73,13 @@ class DeOldifyPredictor(): ...@@ -69,13 +73,13 @@ class DeOldifyPredictor():
def denorm(self, img): def denorm(self, img):
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
img *= img_std img *= img_std
img += img_mean img += img_mean
img = img.transpose((1, 2, 0)) img = img.transpose((1, 2, 0))
return (img * 255).clip(0, 255).astype('uint8') return (img * 255).clip(0, 255).astype('uint8')
def post_process(self, raw_color, orig): def post_process(self, raw_color, orig):
color_np = np.asarray(raw_color) color_np = np.asarray(raw_color)
orig_np = np.asarray(orig) orig_np = np.asarray(orig)
...@@ -86,11 +90,11 @@ class DeOldifyPredictor(): ...@@ -86,11 +90,11 @@ class DeOldifyPredictor():
final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR) final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
final = Image.fromarray(final) final = Image.fromarray(final)
return final return final
def run_single(self, img_path): def run_single(self, img_path):
ori_img = Image.open(img_path).convert('LA').convert('RGB') ori_img = Image.open(img_path).convert('LA').convert('RGB')
img = self.norm(ori_img) img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis,...]) x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x) out = self.model(x)
pred_img = self.denorm(out.numpy()[0]) pred_img = self.denorm(out.numpy()[0])
...@@ -118,20 +122,20 @@ class DeOldifyPredictor(): ...@@ -118,20 +122,20 @@ class DeOldifyPredictor():
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
for frame in tqdm(frames): for frame in tqdm(frames):
pred_img = self.run_single(frame) pred_img = self.run_single(frame)
frame_name = os.path.basename(frame) frame_name = os.path.basename(frame)
pred_img.save(os.path.join(pred_frame_path, frame_name)) pred_img.save(os.path.join(pred_frame_path, frame_name))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(output_path, '{}_deoldify_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps)))
return frame_pattern_combined, vid_out_path vid_out_path = os.path.join(output_path,
'{}_deoldify_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
str(int(fps)))
return frame_pattern_combined, vid_out_path
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
...@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
if ss is not None and t is not None and r is not None: if ss is not None and t is not None and r is not None:
cmd = ffmpeg + [ cmd = ffmpeg + [
' -ss ', ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
ss, ' 0.1 ', ' -start_number ', ' 0 ', outformat
' -t ',
t,
' -i ',
vid_path,
' -r ',
r,
' -qscale:v ',
' 0.1 ',
' -start_number ',
' 0 ',
outformat
] ]
else: else:
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
...@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
return out_full_path return out_full_path
if __name__=='__main__': if __name__ == '__main__':
paddle.enable_imperative() paddle.enable_imperative()
args = parser.parse_args() args = parser.parse_args()
predictor = DeOldifyPredictor(args.input, args.output, weight_path=args.weight_path) predictor = DeOldifyPredictor(args.input,
args.output,
weight_path=args.weight_path)
frames_path, temp_video_path = predictor.run() frames_path, temp_video_path = predictor.run()
print('output video path:', temp_video_path) print('output video path:', temp_video_path)
\ No newline at end of file
...@@ -15,195 +15,235 @@ import argparse ...@@ -15,195 +15,235 @@ import argparse
import subprocess import subprocess
import utils import utils
from remasternet import NetworkR, NetworkC from remasternet import NetworkR, NetworkC
from paddle.incubate.hapi.download import get_path_from_url from paddle.utils.download import get_path_from_url
DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams' DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
parser = argparse.ArgumentParser(description='Remastering') parser = argparse.ArgumentParser(description='Remastering')
parser.add_argument('--input', type=str, default=None, help='Input video') parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir') parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory') parser.add_argument('--reference_dir',
parser.add_argument('--colorization', action='store_true', default=False, help='Remaster without colorization') type=str,
parser.add_argument('--mindim', type=int, default='360', help='Length of minimum image edges') default=None,
help='Path to the reference image directory')
parser.add_argument('--colorization',
action='store_true',
default=False,
help='Remaster without colorization')
parser.add_argument('--mindim',
type=int,
default='360',
help='Length of minimum image edges')
class DeepReasterPredictor: class DeepReasterPredictor:
def __init__(self, input, output, weight_path=None, colorization=False, reference_dir=None, mindim=360): def __init__(self,
self.input = input input,
self.output = os.path.join(output, 'DeepRemaster') output,
self.colorization = colorization weight_path=None,
self.reference_dir = reference_dir colorization=False,
self.mindim = mindim reference_dir=None,
mindim=360):
if weight_path is None: self.input = input
weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path) self.output = os.path.join(output, 'DeepRemaster')
self.colorization = colorization
state_dict, _ = paddle.load(weight_path) self.reference_dir = reference_dir
self.mindim = mindim
self.modelR = NetworkR()
self.modelR.load_dict(state_dict['modelR']) if weight_path is None:
self.modelR.eval() weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path)
if colorization:
self.modelC = NetworkC() state_dict, _ = paddle.load(weight_path)
self.modelC.load_dict(state_dict['modelC'])
self.modelC.eval() self.modelR = NetworkR()
self.modelR.load_dict(state_dict['modelR'])
self.modelR.eval()
def run(self): if colorization:
outputdir = self.output self.modelC = NetworkC()
outputdir_in = os.path.join(outputdir, 'input/') self.modelC.load_dict(state_dict['modelC'])
os.makedirs( outputdir_in, exist_ok=True ) self.modelC.eval()
outputdir_out = os.path.join(outputdir, 'output/')
os.makedirs( outputdir_out, exist_ok=True ) def run(self):
outputdir = self.output
# Prepare reference images outputdir_in = os.path.join(outputdir, 'input/')
if self.colorization: os.makedirs(outputdir_in, exist_ok=True)
if self.reference_dir is not None: outputdir_out = os.path.join(outputdir, 'output/')
import glob os.makedirs(outputdir_out, exist_ok=True)
ext_list = ['png','jpg','bmp']
reference_files = [] # Prepare reference images
for ext in ext_list: if self.colorization:
reference_files += glob.glob( self.reference_dir+'/*.'+ext, recursive=True ) if self.reference_dir is not None:
aspect_mean = 0 import glob
minedge_dim = 256 ext_list = ['png', 'jpg', 'bmp']
refs = [] reference_files = []
for v in reference_files: for ext in ext_list:
refimg = Image.open( v ).convert('RGB') reference_files += glob.glob(self.reference_dir + '/*.' +
w, h = refimg.size ext,
aspect_mean += w/h recursive=True)
refs.append( refimg ) aspect_mean = 0
aspect_mean /= len(reference_files) minedge_dim = 256
target_w = int(256*aspect_mean) if aspect_mean>1 else 256 refs = []
target_h = 256 if aspect_mean>=1 else int(256/aspect_mean) for v in reference_files:
refimg = Image.open(v).convert('RGB')
refimgs = [] w, h = refimg.size
for i, v in enumerate(refs): aspect_mean += w / h
refimg = utils.addMergin( v, target_w=target_w, target_h=target_h ) refs.append(refimg)
refimg = np.array(refimg).astype('float32').transpose(2, 0, 1) / 255.0 aspect_mean /= len(reference_files)
refimgs.append(refimg) target_w = int(256 * aspect_mean) if aspect_mean > 1 else 256
refimgs = paddle.to_tensor(np.array(refimgs).astype('float32')) target_h = 256 if aspect_mean >= 1 else int(256 / aspect_mean)
refimgs = paddle.unsqueeze(refimgs, 0) refimgs = []
for i, v in enumerate(refs):
# Load video refimg = utils.addMergin(v,
cap = cv2.VideoCapture( self.input ) target_w=target_w,
nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) target_h=target_h)
v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH) refimg = np.array(refimg).astype('float32').transpose(
v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) 2, 0, 1) / 255.0
minwh = min(v_w,v_h) refimgs.append(refimg)
scale = 1 refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
if minwh != self.mindim:
scale = self.mindim / minwh refimgs = paddle.unsqueeze(refimgs, 0)
t_w = round(v_w*scale/16.)*16 # Load video
t_h = round(v_h*scale/16.)*16 cap = cv2.VideoCapture(self.input)
fps = cap.get(cv2.CAP_PROP_FPS) nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
pbar = tqdm(total=nframes) v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
block = 5 v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
minwh = min(v_w, v_h)
# Process scale = 1
with paddle.no_grad(): if minwh != self.mindim:
it = 0 scale = self.mindim / minwh
while True:
frame_pos = it*block t_w = round(v_w * scale / 16.) * 16
if frame_pos >= nframes: t_h = round(v_h * scale / 16.) * 16
break fps = cap.get(cv2.CAP_PROP_FPS)
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos) pbar = tqdm(total=nframes)
if block >= nframes-frame_pos: block = 5
proc_g = nframes-frame_pos
else: # Process
proc_g = block with paddle.no_grad():
it = 0
input = None while True:
gtC = None frame_pos = it * block
for i in range(proc_g): if frame_pos >= nframes:
index = frame_pos + i break
_, frame = cap.read() cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
frame = cv2.resize(frame, (t_w, t_h)) if block >= nframes - frame_pos:
nchannels = frame.shape[2] proc_g = nframes - frame_pos
if nchannels == 1 or self.colorization: else:
frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) proc_g = block
cv2.imwrite(outputdir_in+'%07d.png'%index, frame_l)
frame_l = paddle.to_tensor(frame_l.astype('float32')) input = None
frame_l = paddle.reshape(frame_l, [frame_l.shape[0], frame_l.shape[1], 1]) gtC = None
frame_l = paddle.transpose(frame_l, [2, 0, 1]) for i in range(proc_g):
frame_l /= 255. index = frame_pos + i
_, frame = cap.read()
frame_l = paddle.reshape(frame_l, [1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]]) frame = cv2.resize(frame, (t_w, t_h))
elif nchannels == 3: nchannels = frame.shape[2]
cv2.imwrite(outputdir_in+'%07d.png'%index, frame) if nchannels == 1 or self.colorization:
frame = frame[:,:,::-1] ## BGR -> RGB frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame_l, frame_ab = utils.convertRGB2LABTensor( frame ) cv2.imwrite(outputdir_in + '%07d.png' % index, frame_l)
frame_l = frame_l.transpose([2, 0, 1]) frame_l = paddle.to_tensor(frame_l.astype('float32'))
frame_ab = frame_ab.transpose([2, 0, 1]) frame_l = paddle.reshape(
frame_l = frame_l.reshape([1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]]) frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
frame_ab = frame_ab.reshape([1, frame_ab.shape[0], 1, frame_ab.shape[1], frame_ab.shape[2]]) frame_l = paddle.transpose(frame_l, [2, 0, 1])
frame_l /= 255.
if input is not None:
paddle.concat( (input, frame_l), 2 ) frame_l = paddle.reshape(frame_l, [
1, frame_l.shape[0], 1, frame_l.shape[1],
input = frame_l if i==0 else paddle.concat( (input, frame_l), 2 ) frame_l.shape[2]
if nchannels==3 and not self.colorization: ])
gtC = frame_ab if i==0 else paddle.concat( (gtC, frame_ab), 2 ) elif nchannels == 3:
cv2.imwrite(outputdir_in + '%07d.png' % index, frame)
input = paddle.to_tensor(input) frame = frame[:, :, ::-1] ## BGR -> RGB
frame_l, frame_ab = utils.convertRGB2LABTensor(frame)
frame_l = frame_l.transpose([2, 0, 1])
output_l = self.modelR( input ) # [B, C, T, H, W] frame_ab = frame_ab.transpose([2, 0, 1])
frame_l = frame_l.reshape([
# Save restoration output without colorization when using the option [--disable_colorization] 1, frame_l.shape[0], 1, frame_l.shape[1],
if not self.colorization: frame_l.shape[2]
for i in range( proc_g ): ])
index = frame_pos + i frame_ab = frame_ab.reshape([
if nchannels==3: 1, frame_ab.shape[0], 1, frame_ab.shape[1],
out_l = output_l.detach()[0,:,i] frame_ab.shape[2]
out_ab = gtC[0,:,i] ])
out = paddle.concat((out_l, out_ab),axis=0).detach().numpy().transpose((1, 2, 0)) if input is not None:
out = Image.fromarray( np.uint8( utils.convertLAB2RGB( out )*255 ) ) paddle.concat((input, frame_l), 2)
out.save( outputdir_out+'%07d.png'%(index) )
else: input = frame_l if i == 0 else paddle.concat(
raise ValueError('channels of imag3 must be 3!') (input, frame_l), 2)
if nchannels == 3 and not self.colorization:
# Perform colorization gtC = frame_ab if i == 0 else paddle.concat(
else: (gtC, frame_ab), 2)
if self.reference_dir is None:
output_ab = self.modelC( output_l ) input = paddle.to_tensor(input)
else:
output_ab = self.modelC( output_l, refimgs ) output_l = self.modelR(input) # [B, C, T, H, W]
output_l = output_l.detach()
output_ab = output_ab.detach() # Save restoration output without colorization when using the option [--disable_colorization]
if not self.colorization:
for i in range(proc_g):
for i in range( proc_g ): index = frame_pos + i
index = frame_pos + i if nchannels == 3:
out_l = output_l[0,:,i,:,:] out_l = output_l.detach()[0, :, i]
out_c = output_ab[0,:,i,:,:] out_ab = gtC[0, :, i]
output = paddle.concat((out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
output = Image.fromarray( np.uint8( utils.convertLAB2RGB( output )*255 ) ) out = paddle.concat(
output.save( outputdir_out+'%07d.png'%index ) (out_l, out_ab),
axis=0).detach().numpy().transpose((1, 2, 0))
it = it + 1 out = Image.fromarray(
pbar.update(proc_g) np.uint8(utils.convertLAB2RGB(out) * 255))
out.save(outputdir_out + '%07d.png' % (index))
# Save result videos else:
outfile = os.path.join(outputdir, self.input.split('/')[-1].split('.')[0]) raise ValueError('channels of imag3 must be 3!')
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (fps, outputdir_in, fps, outfile )
subprocess.call( cmd, shell=True ) # Perform colorization
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (fps, outputdir_out, fps, outfile ) else:
subprocess.call( cmd, shell=True ) if self.reference_dir is None:
cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % ( outfile, outfile, outfile ) output_ab = self.modelC(output_l)
subprocess.call( cmd, shell=True ) else:
output_ab = self.modelC(output_l, refimgs)
cap.release() output_l = output_l.detach()
pbar.close() output_ab = output_ab.detach()
return outputdir_out, '%s_out.mp4' % outfile
for i in range(proc_g):
index = frame_pos + i
out_l = output_l[0, :, i, :, :]
out_c = output_ab[0, :, i, :, :]
output = paddle.concat(
(out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
output = Image.fromarray(
np.uint8(utils.convertLAB2RGB(output) * 255))
output.save(outputdir_out + '%07d.png' % index)
it = it + 1
pbar.update(proc_g)
# Save result videos
outfile = os.path.join(outputdir,
self.input.split('/')[-1].split('.')[0])
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (
fps, outputdir_in, fps, outfile)
subprocess.call(cmd, shell=True)
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (
fps, outputdir_out, fps, outfile)
subprocess.call(cmd, shell=True)
cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % (
outfile, outfile, outfile)
subprocess.call(cmd, shell=True)
cap.release()
pbar.close()
return outputdir_out, '%s_out.mp4' % outfile
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
paddle.disable_static() paddle.disable_static()
predictor = DeepReasterPredictor(args.input, args.output, colorization=args.colorization, predictor = DeepReasterPredictor(args.input,
reference_dir=args.reference_dir, mindim=args.mindim) args.output,
predictor.run() colorization=args.colorization,
reference_dir=args.reference_dir,
\ No newline at end of file mindim=args.mindim)
predictor.run()
...@@ -28,30 +28,29 @@ import paddle.fluid as fluid ...@@ -28,30 +28,29 @@ import paddle.fluid as fluid
import cv2 import cv2
from data import EDVRDataset from data import EDVRDataset
from paddle.incubate.hapi.download import get_path_from_url from paddle.utils.download import get_path_from_url
EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar' EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument('--input',
'--input', type=str,
type=str, default=None,
default=None, help='input video path')
help='input video path') parser.add_argument('--output',
parser.add_argument( type=str,
'--output', default='output',
type=str, help='output path')
default='output', parser.add_argument('--weight_path',
help='output path') type=str,
parser.add_argument( default=None,
'--weight_path', help='weight path')
type=str,
default=None,
help='weight path')
args = parser.parse_args() args = parser.parse_args()
return args return args
def get_img(pred): def get_img(pred):
print('pred shape', pred.shape) print('pred shape', pred.shape)
pred = pred.squeeze() pred = pred.squeeze()
...@@ -59,10 +58,11 @@ def get_img(pred): ...@@ -59,10 +58,11 @@ def get_img(pred):
pred = pred * 255 pred = pred * 255
pred = pred.round() pred = pred.round()
pred = pred.astype('uint8') pred = pred.astype('uint8')
pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc
pred = pred[:, :, ::-1] # rgb -> bgr pred = pred[:, :, ::-1] # rgb -> bgr
return pred return pred
def save_img(img, framename): def save_img(img, framename):
dirname = os.path.dirname(framename) dirname = os.path.dirname(framename)
if not os.path.exists(dirname): if not os.path.exists(dirname):
...@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
if ss is not None and t is not None and r is not None: if ss is not None and t is not None and r is not None:
cmd = ffmpeg + [ cmd = ffmpeg + [
' -ss ', ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
ss, ' 0.1 ', ' -start_number ', ' 0 ', outformat
' -t ',
t,
' -i ',
vid_path,
' -r ',
r,
' -qscale:v ',
' 0.1 ',
' -start_number ',
' 0 ',
outformat
] ]
else: else:
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
...@@ -134,20 +123,21 @@ class EDVRPredictor: ...@@ -134,20 +123,21 @@ class EDVRPredictor:
self.input = input self.input = input
self.output = os.path.join(output, 'EDVR') self.output = os.path.join(output, 'EDVR')
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace() place = fluid.CUDAPlace(
0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
self.exe = fluid.Executor(place) self.exe = fluid.Executor(place)
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(EDVR_weight_url, cur_path) weight_path = get_path_from_url(EDVR_weight_url, cur_path)
print(weight_path) print(weight_path)
model_filename = 'EDVR_model.pdmodel' model_filename = 'EDVR_model.pdmodel'
params_filename = 'EDVR_params.pdparams' params_filename = 'EDVR_params.pdparams'
out = fluid.io.load_inference_model(dirname=weight_path, out = fluid.io.load_inference_model(dirname=weight_path,
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename, params_filename=params_filename,
executor=self.exe) executor=self.exe)
self.infer_prog, self.feed_list, self.fetch_list = out self.infer_prog, self.feed_list, self.fetch_list = out
...@@ -176,16 +166,19 @@ class EDVRPredictor: ...@@ -176,16 +166,19 @@ class EDVRPredictor:
cur_time = time.time() cur_time = time.time()
for infer_iter, data in enumerate(dataset): for infer_iter, data in enumerate(dataset):
data_feed_in = [data[0]] data_feed_in = [data[0]]
infer_outs = self.exe.run(self.infer_prog, infer_outs = self.exe.run(
fetch_list=self.fetch_list, self.infer_prog,
feed={self.feed_list[0]:np.array(data_feed_in)}) fetch_list=self.fetch_list,
feed={self.feed_list[0]: np.array(data_feed_in)})
infer_result_list = [item for item in infer_outs] infer_result_list = [item for item in infer_outs]
frame_path = data[1] frame_path = data[1]
img_i = get_img(infer_result_list[0]) img_i = get_img(infer_result_list[0])
save_img(img_i, os.path.join(pred_frame_path, os.path.basename(frame_path))) save_img(
img_i,
os.path.join(pred_frame_path, os.path.basename(frame_path)))
prev_time = cur_time prev_time = cur_time
cur_time = time.time() cur_time = time.time()
...@@ -194,13 +187,15 @@ class EDVRPredictor: ...@@ -194,13 +187,15 @@ class EDVRPredictor:
print('Processed {} samples'.format(infer_iter + 1)) print('Processed {} samples'.format(infer_iter + 1))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(self.output, '{}_edvr_out.mp4'.format(base_name)) vid_out_path = os.path.join(self.output,
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps))) '{}_edvr_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args()
predictor = EDVRPredictor(args.input, args.output, args.weight_path) predictor = EDVRPredictor(args.input, args.output, args.weight_path)
predictor.run() predictor.run()
...@@ -60,11 +60,12 @@ dataset: ...@@ -60,11 +60,12 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
name: linear lr_scheduler:
learning_rate: 0.0002 name: linear
start_epoch: 100 learning_rate: 0.0002
decay_epochs: 100 start_epoch: 100
decay_epochs: 100
log_config: log_config:
interval: 100 interval: 100
...@@ -72,4 +73,3 @@ log_config: ...@@ -72,4 +73,3 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
...@@ -59,11 +59,12 @@ dataset: ...@@ -59,11 +59,12 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
name: linear lr_scheduler:
learning_rate: 0.0002 name: linear
start_epoch: 100 learning_rate: 0.0002
decay_epochs: 100 start_epoch: 100
decay_epochs: 100
log_config: log_config:
interval: 100 interval: 100
...@@ -71,4 +72,3 @@ log_config: ...@@ -71,4 +72,3 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
...@@ -25,7 +25,7 @@ dataset: ...@@ -25,7 +25,7 @@ dataset:
train: train:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes dataroot: data/cityscapes
num_workers: 0 num_workers: 4
phase: train phase: train
max_dataset_size: inf max_dataset_size: inf
direction: BtoA direction: BtoA
...@@ -57,11 +57,12 @@ dataset: ...@@ -57,11 +57,12 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
name: linear lr_scheduler:
learning_rate: 0.0002 name: linear
start_epoch: 100 learning_rate: 0.0002
decay_epochs: 100 start_epoch: 100
decay_epochs: 100
log_config: log_config:
interval: 100 interval: 100
...@@ -69,4 +70,3 @@ log_config: ...@@ -69,4 +70,3 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
...@@ -56,11 +56,12 @@ dataset: ...@@ -56,11 +56,12 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
name: linear lr_scheduler:
learning_rate: 0.0004 name: linear
start_epoch: 100 learning_rate: 0.0004
decay_epochs: 100 start_epoch: 100
decay_epochs: 100
log_config: log_config:
interval: 100 interval: 100
...@@ -68,4 +69,3 @@ log_config: ...@@ -68,4 +69,3 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
...@@ -56,11 +56,12 @@ dataset: ...@@ -56,11 +56,12 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
name: linear lr_scheduler:
learning_rate: 0.0002 name: linear
start_epoch: 100 learning_rate: 0.0002
decay_epochs: 100 start_epoch: 100
decay_epochs: 100
log_config: log_config:
interval: 100 interval: 100
......
...@@ -6,7 +6,7 @@ from paddle.io import Dataset ...@@ -6,7 +6,7 @@ from paddle.io import Dataset
from PIL import Image from PIL import Image
import cv2 import cv2
import paddle.incubate.hapi.vision.transforms as transforms import paddle.vision.transforms as transforms
from .transforms import transforms as T from .transforms import transforms as T
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod ...@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets. """This class is an abstract base class (ABC) for datasets.
""" """
def __init__(self, cfg): def __init__(self, cfg):
"""Initialize the class; save the options in the class """Initialize the class; save the options in the class
...@@ -60,8 +59,11 @@ def get_params(cfg, size): ...@@ -60,8 +59,11 @@ def get_params(cfg, size):
return {'crop_pos': (x, y), 'flip': flip} return {'crop_pos': (x, y), 'flip': flip}
def get_transform(cfg,
def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, convert=True): params=None,
grayscale=False,
method=cv2.INTER_CUBIC,
convert=True):
transform_list = [] transform_list = []
if grayscale: if grayscale:
print('grayscale not support for now!!!') print('grayscale not support for now!!!')
...@@ -89,8 +91,10 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con ...@@ -89,8 +91,10 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con
transform_list.append(transforms.RandomHorizontalFlip()) transform_list.append(transforms.RandomHorizontalFlip())
elif params['flip']: elif params['flip']:
transform_list.append(transforms.RandomHorizontalFlip(1.0)) transform_list.append(transforms.RandomHorizontalFlip(1.0))
if convert: if convert:
transform_list += [transforms.Permute(to_rgb=True)] transform_list += [transforms.Permute(to_rgb=True)]
transform_list += [transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))] transform_list += [
transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
]
return transforms.Compose(transform_list) return transforms.Compose(transform_list)
...@@ -3,12 +3,11 @@ import paddle ...@@ -3,12 +3,11 @@ import paddle
import numbers import numbers
import numpy as np import numpy as np
from multiprocessing import Manager from multiprocessing import Manager
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
from paddle.incubate.hapi.distributed import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from ..utils.registry import Registry from ..utils.registry import Registry
DATASETS = Registry("DATASETS") DATASETS = Registry("DATASETS")
...@@ -21,7 +20,7 @@ class DictDataset(paddle.io.Dataset): ...@@ -21,7 +20,7 @@ class DictDataset(paddle.io.Dataset):
single_item = dataset[0] single_item = dataset[0]
self.keys = single_item.keys() self.keys = single_item.keys()
for k, v in single_item.items(): for k, v in single_item.items():
if not isinstance(v, (numbers.Number, np.ndarray)): if not isinstance(v, (numbers.Number, np.ndarray)):
setattr(self, k, Manager().dict()) setattr(self, k, Manager().dict())
...@@ -32,9 +31,9 @@ class DictDataset(paddle.io.Dataset): ...@@ -32,9 +31,9 @@ class DictDataset(paddle.io.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
ori_map = self.dataset[index] ori_map = self.dataset[index]
tmp_list = [] tmp_list = []
for k, v in ori_map.items(): for k, v in ori_map.items():
if isinstance(v, (numbers.Number, np.ndarray)): if isinstance(v, (numbers.Number, np.ndarray)):
tmp_list.append(v) tmp_list.append(v)
...@@ -60,17 +59,15 @@ class DictDataLoader(): ...@@ -60,17 +59,15 @@ class DictDataLoader():
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \ place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0) if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)
sampler = DistributedBatchSampler( sampler = DistributedBatchSampler(self.dataset,
self.dataset, batch_size=batch_size,
batch_size=batch_size, shuffle=True if is_train else False,
shuffle=True if is_train else False, drop_last=True if is_train else False)
drop_last=True if is_train else False)
self.dataloader = paddle.io.DataLoader( self.dataloader = paddle.io.DataLoader(self.dataset,
self.dataset, batch_sampler=sampler,
batch_sampler=sampler, places=place,
places=place, num_workers=num_workers)
num_workers=num_workers)
self.batch_size = batch_size self.batch_size = batch_size
...@@ -83,7 +80,9 @@ class DictDataLoader(): ...@@ -83,7 +80,9 @@ class DictDataLoader():
j = 0 j = 0
for k in self.dataset.keys: for k in self.dataset.keys:
if k in self.dataset.tensor_keys_set: if k in self.dataset.tensor_keys_set:
return_dict[k] = data[j] if isinstance(data, (list, tuple)) else data return_dict[k] = data[j] if isinstance(data,
(list,
tuple)) else data
j += 1 j += 1
else: else:
return_dict[k] = self.get_items_by_indexs(k, data[-1]) return_dict[k] = self.get_items_by_indexs(k, data[-1])
...@@ -104,13 +103,12 @@ class DictDataLoader(): ...@@ -104,13 +103,12 @@ class DictDataLoader():
return current_items return current_items
def build_dataloader(cfg, is_train=True): def build_dataloader(cfg, is_train=True):
dataset = DATASETS.get(cfg.name)(cfg) dataset = DATASETS.get(cfg.name)(cfg)
batch_size = cfg.get('batch_size', 1) batch_size = cfg.get('batch_size', 1)
num_workers = cfg.get('num_workers', 0) num_workers = cfg.get('num_workers', 0)
dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers) dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)
return dataloader return dataloader
\ No newline at end of file
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import logging import logging
import paddle import paddle
from paddle import ParallelEnv, DataParallel from paddle.distributed import ParallelEnv
from ..datasets.builder import build_dataloader from ..datasets.builder import build_dataloader
from ..models.builder import build_model from ..models.builder import build_model
...@@ -17,10 +17,11 @@ class Trainer: ...@@ -17,10 +17,11 @@ class Trainer:
# build train dataloader # build train dataloader
self.train_dataloader = build_dataloader(cfg.dataset.train) self.train_dataloader = build_dataloader(cfg.dataset.train)
if 'lr_scheduler' in cfg.optimizer: if 'lr_scheduler' in cfg.optimizer:
cfg.optimizer.lr_scheduler.step_per_epoch = len(self.train_dataloader) cfg.optimizer.lr_scheduler.step_per_epoch = len(
self.train_dataloader)
# build model # build model
self.model = build_model(cfg) self.model = build_model(cfg)
# multiple gpus prepare # multiple gpus prepare
...@@ -44,16 +45,17 @@ class Trainer: ...@@ -44,16 +45,17 @@ class Trainer:
# time count # time count
self.time_count = {} self.time_count = {}
def distributed_data_parallel(self): def distributed_data_parallel(self):
strategy = paddle.prepare_context() strategy = paddle.prepare_context()
for name in self.model.model_names: for name in self.model.model_names:
if isinstance(name, str): if isinstance(name, str):
net = getattr(self.model, 'net' + name) net = getattr(self.model, 'net' + name)
setattr(self.model, 'net' + name, DataParallel(net, strategy)) setattr(self.model, 'net' + name,
paddle.DataParallel(net, strategy))
def train(self): def train(self):
for epoch in range(self.start_epoch, self.epochs): for epoch in range(self.start_epoch, self.epochs):
self.current_epoch = epoch self.current_epoch = epoch
start_time = step_start_time = time.time() start_time = step_start_time = time.time()
...@@ -64,24 +66,27 @@ class Trainer: ...@@ -64,24 +66,27 @@ class Trainer:
# data input should be dict # data input should be dict
self.model.set_input(data) self.model.set_input(data)
self.model.optimize_parameters() self.model.optimize_parameters()
self.data_time = data_time - step_start_time self.data_time = data_time - step_start_time
self.step_time = time.time() - step_start_time self.step_time = time.time() - step_start_time
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.print_log() self.print_log()
if i % self.visual_interval == 0: if i % self.visual_interval == 0:
self.visual('visual_train') self.visual('visual_train')
step_start_time = time.time() step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() - start_time)) self.logger.info('train one epoch time: {}'.format(time.time() -
start_time))
self.model.lr_scheduler.step()
if epoch % self.weight_interval == 0: if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1) self.save(epoch, 'weight', keep=-1)
self.save(epoch) self.save(epoch)
def test(self): def test(self):
if not hasattr(self, 'test_dataloader'): if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test, is_train=False) self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False)
# data[0]: img, data[1]: img path index # data[0]: img, data[1]: img path index
# test batch size must be 1 # test batch size must be 1
...@@ -103,14 +108,15 @@ class Trainer: ...@@ -103,14 +108,15 @@ class Trainer:
visual_results.update({name: img_tensor[j]}) visual_results.update({name: img_tensor[j]})
self.visual('visual_test', visual_results=visual_results) self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' % (i, len(self.test_dataloader))) self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader)))
def print_log(self): def print_log(self):
losses = self.model.get_current_losses() losses = self.model.get_current_losses()
message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id) message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id)
message += '%s: %.6f ' % ('lr', self.current_learning_rate) message += '%s: %.6f ' % ('lr', self.current_learning_rate)
for k, v in losses.items(): for k, v in losses.items():
...@@ -143,13 +149,14 @@ class Trainer: ...@@ -143,13 +149,14 @@ class Trainer:
makedirs(os.path.join(self.output_dir, results_dir)) makedirs(os.path.join(self.output_dir, results_dir))
for label, image in visual_results.items(): for label, image in visual_results.items():
image_numpy = tensor2img(image) image_numpy = tensor2img(image)
img_path = os.path.join(self.output_dir, results_dir, msg + '%s.png' % (label)) img_path = os.path.join(self.output_dir, results_dir,
msg + '%s.png' % (label))
save_image(image_numpy, img_path) save_image(image_numpy, img_path)
def save(self, epoch, name='checkpoint', keep=1): def save(self, epoch, name='checkpoint', keep=1):
if self.local_rank != 0: if self.local_rank != 0:
return return
assert name in ['checkpoint', 'weight'] assert name in ['checkpoint', 'weight']
state_dicts = {} state_dicts = {}
...@@ -175,8 +182,8 @@ class Trainer: ...@@ -175,8 +182,8 @@ class Trainer:
if keep > 0: if keep > 0:
try: try:
checkpoint_name_to_be_removed = os.path.join(self.output_dir, checkpoint_name_to_be_removed = os.path.join(
'epoch_%s_%s.pkl' % (epoch - keep, name)) self.output_dir, 'epoch_%s_%s.pkl' % (epoch - keep, name))
if os.path.exists(checkpoint_name_to_be_removed): if os.path.exists(checkpoint_name_to_be_removed):
os.remove(checkpoint_name_to_be_removed) os.remove(checkpoint_name_to_be_removed)
...@@ -187,7 +194,7 @@ class Trainer: ...@@ -187,7 +194,7 @@ class Trainer:
state_dicts = load(checkpoint_path) state_dicts = load(checkpoint_path)
if state_dicts.get('epoch', None) is not None: if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1 self.start_epoch = state_dicts['epoch'] + 1
for name in self.model.model_names: for name in self.model.model_names:
if isinstance(name, str): if isinstance(name, str):
net = getattr(self.model, 'net' + name) net = getattr(self.model, 'net' + name)
...@@ -200,9 +207,8 @@ class Trainer: ...@@ -200,9 +207,8 @@ class Trainer:
def load(self, weight_path): def load(self, weight_path):
state_dicts = load(weight_path) state_dicts = load(weight_path)
for name in self.model.model_names: for name in self.model.model_names:
if isinstance(name, str): if isinstance(name, str):
net = getattr(self.model, 'net' + name) net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name]) net.set_dict(state_dicts['net' + name])
\ No newline at end of file
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from collections import OrderedDict from collections import OrderedDict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ..solver.lr_scheduler import build_lr_scheduler
class BaseModel(ABC): class BaseModel(ABC):
...@@ -16,7 +17,6 @@ class BaseModel(ABC): ...@@ -16,7 +17,6 @@ class BaseModel(ABC):
-- <optimize_parameters>: calculate losses, gradients, and update network weights. -- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options. -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
""" """
def __init__(self, opt): def __init__(self, opt):
"""Initialize the BaseModel class. """Initialize the BaseModel class.
...@@ -33,8 +33,10 @@ class BaseModel(ABC): ...@@ -33,8 +33,10 @@ class BaseModel(ABC):
""" """
self.opt = opt self.opt = opt
self.isTrain = opt.isTrain self.isTrain = opt.isTrain
self.save_dir = os.path.join(opt.output_dir, opt.model.name) # save all the checkpoints to save_dir self.save_dir = os.path.join(
opt.output_dir,
opt.model.name) # save all the checkpoints to save_dir
self.loss_names = [] self.loss_names = []
self.model_names = [] self.model_names = []
self.visual_names = [] self.visual_names = []
...@@ -75,6 +77,8 @@ class BaseModel(ABC): ...@@ -75,6 +77,8 @@ class BaseModel(ABC):
"""Calculate losses, gradients, and update network weights; called in every training iteration""" """Calculate losses, gradients, and update network weights; called in every training iteration"""
pass pass
def build_lr_scheduler(self):
self.lr_scheduler = build_lr_scheduler(self.opt.lr_scheduler)
def eval(self): def eval(self):
"""Make models eval mode during test time""" """Make models eval mode during test time"""
...@@ -114,10 +118,11 @@ class BaseModel(ABC): ...@@ -114,10 +118,11 @@ class BaseModel(ABC):
errors_ret = OrderedDict() errors_ret = OrderedDict()
for name in self.loss_names: for name in self.loss_names:
if isinstance(name, str): if isinstance(name, str):
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number errors_ret[name] = float(
getattr(self, 'loss_' + name)
) # float(...) works for both scalar tensor and float number
return errors_ret return errors_ret
def set_requires_grad(self, nets, requires_grad=False): def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters: Parameters:
......
import paddle import paddle
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
...@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel): ...@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel):
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
""" """
def __init__(self, opt): def __init__(self, opt):
"""Initialize the CycleGAN class. """Initialize the CycleGAN class.
...@@ -32,12 +31,14 @@ class CycleGANModel(BaseModel): ...@@ -32,12 +31,14 @@ class CycleGANModel(BaseModel):
""" """
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'] self.loss_names = [
'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B'] visual_names_B = ['real_B', 'fake_A', 'rec_B']
# if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
if self.isTrain and self.opt.lambda_identity > 0.0: if self.isTrain and self.opt.lambda_identity > 0.0:
visual_names_A.append('idt_B') visual_names_A.append('idt_B')
visual_names_B.append('idt_A') visual_names_B.append('idt_A')
...@@ -62,18 +63,28 @@ class CycleGANModel(BaseModel): ...@@ -62,18 +63,28 @@ class CycleGANModel(BaseModel):
if self.isTrain: if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert(opt.dataset.train.input_nc == opt.dataset.train.output_nc) assert (
opt.dataset.train.input_nc == opt.dataset.train.output_nc)
# create image buffer to store previously generated images # create image buffer to store previously generated images
self.fake_A_pool = ImagePool(opt.dataset.train.pool_size) self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
# create image buffer to store previously generated images # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.dataset.train.pool_size) self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)
# define loss functions # define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode) self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionCycle = paddle.nn.L1Loss() self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss() self.criterionIdt = paddle.nn.L1Loss()
self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG_A.parameters() + self.netG_B.parameters()) self.build_lr_scheduler()
self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters() + self.netD_B.parameters()) self.optimizer_G = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netG_A.parameters() +
self.netG_B.parameters())
self.optimizer_D = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netD_A.parameters() +
self.netD_B.parameters())
self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D) self.optimizers.append(self.optimizer_D)
...@@ -90,7 +101,7 @@ class CycleGANModel(BaseModel): ...@@ -90,7 +101,7 @@ class CycleGANModel(BaseModel):
""" """
mode = 'train' if self.isTrain else 'test' mode = 'train' if self.isTrain else 'test'
AtoB = self.opt.dataset[mode].direction == 'AtoB' AtoB = self.opt.dataset[mode].direction == 'AtoB'
if AtoB: if AtoB:
if 'A' in input: if 'A' in input:
self.real_A = paddle.to_tensor(input['A']) self.real_A = paddle.to_tensor(input['A'])
...@@ -107,17 +118,15 @@ class CycleGANModel(BaseModel): ...@@ -107,17 +118,15 @@ class CycleGANModel(BaseModel):
elif 'B_paths' in input: elif 'B_paths' in input:
self.image_paths = input['B_paths'] self.image_paths = input['B_paths']
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
if hasattr(self, 'real_A'): if hasattr(self, 'real_A'):
self.fake_B = self.netG_A(self.real_A) # G_A(A) self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
if hasattr(self, 'real_B'): if hasattr(self, 'real_B'):
self.fake_A = self.netG_B(self.real_B) # G_B(B) self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
def backward_D_basic(self, netD, real, fake): def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator """Calculate GAN loss for the discriminator
...@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel): ...@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel):
if lambda_idt > 0: if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B|| # G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B) self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A|| # G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A) self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_A) * lambda_A * lambda_idt
else: else:
self.loss_idt_A = 0 self.loss_idt_A = 0
self.loss_idt_B = 0 self.loss_idt_B = 0
...@@ -179,12 +190,14 @@ class CycleGANModel(BaseModel): ...@@ -179,12 +190,14 @@ class CycleGANModel(BaseModel):
# GAN loss D_B(G_B(B)) # GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
# Forward cycle loss || G_B(G_A(A)) - A|| # Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B|| # Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
# combined loss and calculate gradients # combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
if ParallelEnv().nranks > 1: if ParallelEnv().nranks > 1:
self.loss_G = self.netG_A.scale_loss(self.loss_G) self.loss_G = self.netG_A.scale_loss(self.loss_G)
self.loss_G.backward() self.loss_G.backward()
...@@ -216,6 +229,5 @@ class CycleGANModel(BaseModel): ...@@ -216,6 +229,5 @@ class CycleGANModel(BaseModel):
self.backward_D_A() self.backward_D_A()
# calculate graidents for D_B # calculate graidents for D_B
self.backward_D_B() self.backward_D_B()
# update D_A and D_B's weights # update D_A and D_B's weights
self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B) self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)
import paddle import paddle
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
...@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel): ...@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel):
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
""" """
def __init__(self, opt): def __init__(self, opt):
"""Initialize the pix2pix class. """Initialize the pix2pix class.
...@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel): ...@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel):
if self.isTrain: if self.isTrain:
self.netD = build_discriminator(opt.model.discriminator) self.netD = build_discriminator(opt.model.discriminator)
if self.isTrain: if self.isTrain:
# define loss functions # define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode) self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionL1 = paddle.nn.L1Loss() self.criterionL1 = paddle.nn.L1Loss()
# build optimizers # build optimizers
self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters()) self.build_lr_scheduler()
self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters()) self.optimizer_G = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netG.parameters())
self.optimizer_D = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netD.parameters())
self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D) self.optimizers.append(self.optimizer_D)
...@@ -75,7 +80,6 @@ class Pix2PixModel(BaseModel): ...@@ -75,7 +80,6 @@ class Pix2PixModel(BaseModel):
self.real_A = paddle.to_tensor(input['A' if AtoB else 'B']) self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
self.real_B = paddle.to_tensor(input['B' if AtoB else 'A']) self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
self.image_paths = input['A_paths' if AtoB else 'B_paths'] self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
...@@ -84,7 +88,7 @@ class Pix2PixModel(BaseModel): ...@@ -84,7 +88,7 @@ class Pix2PixModel(BaseModel):
def forward_test(self, input): def forward_test(self, input):
input = paddle.imperative.to_variable(input) input = paddle.imperative.to_variable(input)
return self.netG(input) return self.netG(input)
def backward_D(self): def backward_D(self):
"""Calculate GAN loss for the discriminator""" """Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake_B # Fake; stop backprop to the generator by detaching fake_B
...@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel): ...@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel):
pred_fake = self.netD(fake_AB) pred_fake = self.netD(fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True) self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B # Second, G(A) = B
self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 self.loss_G_L1 = self.criterionL1(self.fake_B,
self.real_B) * self.opt.lambda_L1
# combine loss and calculate gradients # combine loss and calculate gradients
self.loss_G = self.loss_G_GAN + self.loss_G_L1 self.loss_G = self.loss_G_GAN + self.loss_G_L1
...@@ -129,12 +134,12 @@ class Pix2PixModel(BaseModel): ...@@ -129,12 +134,12 @@ class Pix2PixModel(BaseModel):
# update D # update D
self.set_requires_grad(self.netD, True) self.set_requires_grad(self.netD, True)
self.optimizer_D.clear_gradients() self.optimizer_D.clear_gradients()
self.backward_D() self.backward_D()
self.optimizer_D.minimize(self.loss_D) self.optimizer_D.minimize(self.loss_D)
# update G # update G
self.set_requires_grad(self.netD, False) self.set_requires_grad(self.netD, False)
self.optimizer_G.clear_gradients() self.optimizer_G.clear_gradients()
self.backward_G() self.backward_G()
self.optimizer_G.minimize(self.loss_G) self.optimizer_G.minimize(self.loss_G)
...@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg): ...@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg):
# TODO: add more learning rate scheduler # TODO: add more learning rate scheduler
if name == 'linear': if name == 'linear':
return LinearDecay(**cfg)
def lambda_rule(epoch):
lr_l = 1.0 - max(
0, epoch + 1 - cfg.start_epoch) / float(cfg.decay_epochs + 1)
return lr_l
scheduler = paddle.optimizer.lr_scheduler.LambdaLR(
cfg.learning_rate, lr_lambda=lambda_rule)
return scheduler
else: else:
raise NotImplementedError raise NotImplementedError
class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay): # paddle.optimizer.lr_scheduler
def __init__(self, learning_rate, step_per_epoch, start_epoch, decay_epochs): class LinearDecay(paddle.optimizer.lr_scheduler._LRScheduler):
def __init__(self, learning_rate, step_per_epoch, start_epoch,
decay_epochs):
super(LinearDecay, self).__init__() super(LinearDecay, self).__init__()
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.start_epoch = start_epoch self.start_epoch = start_epoch
...@@ -21,5 +31,6 @@ class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay ...@@ -21,5 +31,6 @@ class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay
def step(self): def step(self):
cur_epoch = int(self.step_num // self.step_per_epoch) cur_epoch = int(self.step_num // self.step_per_epoch)
decay_rate = 1.0 - max(0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1) decay_rate = 1.0 - max(
return self.create_lr_var(decay_rate * self.learning_rate) 0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
\ No newline at end of file return self.create_lr_var(decay_rate * self.learning_rate)
...@@ -4,13 +4,11 @@ import paddle ...@@ -4,13 +4,11 @@ import paddle
from .lr_scheduler import build_lr_scheduler from .lr_scheduler import build_lr_scheduler
def build_optimizer(cfg, parameter_list=None): def build_optimizer(cfg, lr_scheduler, parameter_list=None):
cfg_copy = copy.deepcopy(cfg) cfg_copy = copy.deepcopy(cfg)
lr_scheduler_cfg = cfg_copy.pop('lr_scheduler', None)
lr_scheduler = build_lr_scheduler(lr_scheduler_cfg)
opt_name = cfg_copy.pop('name') opt_name = cfg_copy.pop('name')
return getattr(paddle.optimizer, opt_name)(lr_scheduler, parameters=parameter_list, **cfg_copy) return getattr(paddle.optimizer, opt_name)(lr_scheduler,
parameters=parameter_list,
**cfg_copy)
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import os import os
import sys import sys
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
def setup_logger(output=None, name="ppgan"): def setup_logger(output=None, name="ppgan"):
...@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"): ...@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"):
logger.propagate = False logger.propagate = False
plain_formatter = logging.Formatter( plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
) datefmt="%m/%d %H:%M:%S")
# stdout logging: master only # stdout logging: master only
local_rank = ParallelEnv().local_rank local_rank = ParallelEnv().local_rank
if local_rank == 0: if local_rank == 0:
...@@ -52,4 +52,4 @@ def setup_logger(output=None, name="ppgan"): ...@@ -52,4 +52,4 @@ def setup_logger(output=None, name="ppgan"):
fh.setFormatter(plain_formatter) fh.setFormatter(plain_formatter)
logger.addHandler(fh) logger.addHandler(fh)
return logger return logger
\ No newline at end of file
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import time import time
import paddle import paddle
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
from .logger import setup_logger from .logger import setup_logger
...@@ -12,7 +12,8 @@ def setup(args, cfg): ...@@ -12,7 +12,8 @@ def setup(args, cfg):
cfg.isTrain = False cfg.isTrain = False
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
cfg.output_dir = os.path.join(cfg.output_dir, str(cfg.model.name) + cfg.timestamp) cfg.output_dir = os.path.join(cfg.output_dir,
str(cfg.model.name) + cfg.timestamp)
logger = setup_logger(cfg.output_dir) logger = setup_logger(cfg.output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册