未验证 提交 b306aa73 编写于 作者: L LielinJiang 提交者: GitHub

Merge pull request #16 from LielinJiang/adapt-to-2.0-api

Adapt to api 2.0 again
...@@ -11,7 +11,7 @@ from imageio import imread, imsave ...@@ -11,7 +11,7 @@ from imageio import imread, imsave
import cv2 import cv2
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.incubate.hapi.download import get_path_from_url from paddle.utils.download import get_path_from_url
import networks import networks
from util import * from util import *
...@@ -19,6 +19,7 @@ from my_args import parser ...@@ -19,6 +19,7 @@ from my_args import parser
DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar' DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'
def infer_engine(model_dir, def infer_engine(model_dir,
run_mode='fluid', run_mode='fluid',
batch_size=1, batch_size=1,
...@@ -91,7 +92,6 @@ class VideoFrameInterp(object): ...@@ -91,7 +92,6 @@ class VideoFrameInterp(object):
self.exe, self.program, self.fetch_targets = executor(model_path, self.exe, self.program, self.fetch_targets = executor(model_path,
use_gpu=use_gpu) use_gpu=use_gpu)
def run(self): def run(self):
frame_path_input = os.path.join(self.output_path, 'frames-input') frame_path_input = os.path.join(self.output_path, 'frames-input')
frame_path_interpolated = os.path.join(self.output_path, frame_path_interpolated = os.path.join(self.output_path,
......
...@@ -15,15 +15,19 @@ from PIL import Image ...@@ -15,15 +15,19 @@ from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from paddle import fluid from paddle import fluid
from model import build_model from model import build_model
from paddle.incubate.hapi.download import get_path_from_url from paddle.utils.download import get_path_from_url
parser = argparse.ArgumentParser(description='DeOldify') parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video') parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir') parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--weight_path', type=str, default='none', help='Path to the reference image directory') parser.add_argument('--weight_path',
type=str,
default='none',
help='Path to the reference image directory')
DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r): def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
cmd = ffmpeg + [ cmd = ffmpeg + [
...@@ -90,7 +94,7 @@ class DeOldifyPredictor(): ...@@ -90,7 +94,7 @@ class DeOldifyPredictor():
def run_single(self, img_path): def run_single(self, img_path):
ori_img = Image.open(img_path).convert('LA').convert('RGB') ori_img = Image.open(img_path).convert('LA').convert('RGB')
img = self.norm(ori_img) img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis,...]) x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x) out = self.model(x)
pred_img = self.denorm(out.numpy()[0]) pred_img = self.denorm(out.numpy()[0])
...@@ -118,7 +122,6 @@ class DeOldifyPredictor(): ...@@ -118,7 +122,6 @@ class DeOldifyPredictor():
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
for frame in tqdm(frames): for frame in tqdm(frames):
pred_img = self.run_single(frame) pred_img = self.run_single(frame)
...@@ -127,13 +130,14 @@ class DeOldifyPredictor(): ...@@ -127,13 +130,14 @@ class DeOldifyPredictor():
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(output_path, '{}_deoldify_out.mp4'.format(base_name)) vid_out_path = os.path.join(output_path,
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps))) '{}_deoldify_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
vid_name = vid_path.split('/')[-1].split('.')[0] vid_name = vid_path.split('/')[-1].split('.')[0]
...@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
if ss is not None and t is not None and r is not None: if ss is not None and t is not None and r is not None:
cmd = ffmpeg + [ cmd = ffmpeg + [
' -ss ', ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
ss, ' 0.1 ', ' -start_number ', ' 0 ', outformat
' -t ',
t,
' -i ',
vid_path,
' -r ',
r,
' -qscale:v ',
' 0.1 ',
' -start_number ',
' 0 ',
outformat
] ]
else: else:
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
...@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
return out_full_path return out_full_path
if __name__=='__main__': if __name__ == '__main__':
paddle.enable_imperative() paddle.enable_imperative()
args = parser.parse_args() args = parser.parse_args()
predictor = DeOldifyPredictor(args.input, args.output, weight_path=args.weight_path) predictor = DeOldifyPredictor(args.input,
args.output,
weight_path=args.weight_path)
frames_path, temp_video_path = predictor.run() frames_path, temp_video_path = predictor.run()
print('output video path:', temp_video_path) print('output video path:', temp_video_path)
...@@ -15,20 +15,35 @@ import argparse ...@@ -15,20 +15,35 @@ import argparse
import subprocess import subprocess
import utils import utils
from remasternet import NetworkR, NetworkC from remasternet import NetworkR, NetworkC
from paddle.incubate.hapi.download import get_path_from_url from paddle.utils.download import get_path_from_url
DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams' DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
parser = argparse.ArgumentParser(description='Remastering') parser = argparse.ArgumentParser(description='Remastering')
parser.add_argument('--input', type=str, default=None, help='Input video') parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir') parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory') parser.add_argument('--reference_dir',
parser.add_argument('--colorization', action='store_true', default=False, help='Remaster without colorization') type=str,
parser.add_argument('--mindim', type=int, default='360', help='Length of minimum image edges') default=None,
help='Path to the reference image directory')
parser.add_argument('--colorization',
action='store_true',
default=False,
help='Remaster without colorization')
parser.add_argument('--mindim',
type=int,
default='360',
help='Length of minimum image edges')
class DeepReasterPredictor: class DeepReasterPredictor:
def __init__(self, input, output, weight_path=None, colorization=False, reference_dir=None, mindim=360): def __init__(self,
input,
output,
weight_path=None,
colorization=False,
reference_dir=None,
mindim=360):
self.input = input self.input = input
self.output = os.path.join(output, 'DeepRemaster') self.output = os.path.join(output, 'DeepRemaster')
self.colorization = colorization self.colorization = colorization
...@@ -48,55 +63,59 @@ class DeepReasterPredictor: ...@@ -48,55 +63,59 @@ class DeepReasterPredictor:
self.modelC.load_dict(state_dict['modelC']) self.modelC.load_dict(state_dict['modelC'])
self.modelC.eval() self.modelC.eval()
def run(self): def run(self):
outputdir = self.output outputdir = self.output
outputdir_in = os.path.join(outputdir, 'input/') outputdir_in = os.path.join(outputdir, 'input/')
os.makedirs( outputdir_in, exist_ok=True ) os.makedirs(outputdir_in, exist_ok=True)
outputdir_out = os.path.join(outputdir, 'output/') outputdir_out = os.path.join(outputdir, 'output/')
os.makedirs( outputdir_out, exist_ok=True ) os.makedirs(outputdir_out, exist_ok=True)
# Prepare reference images # Prepare reference images
if self.colorization: if self.colorization:
if self.reference_dir is not None: if self.reference_dir is not None:
import glob import glob
ext_list = ['png','jpg','bmp'] ext_list = ['png', 'jpg', 'bmp']
reference_files = [] reference_files = []
for ext in ext_list: for ext in ext_list:
reference_files += glob.glob( self.reference_dir+'/*.'+ext, recursive=True ) reference_files += glob.glob(self.reference_dir + '/*.' +
ext,
recursive=True)
aspect_mean = 0 aspect_mean = 0
minedge_dim = 256 minedge_dim = 256
refs = [] refs = []
for v in reference_files: for v in reference_files:
refimg = Image.open( v ).convert('RGB') refimg = Image.open(v).convert('RGB')
w, h = refimg.size w, h = refimg.size
aspect_mean += w/h aspect_mean += w / h
refs.append( refimg ) refs.append(refimg)
aspect_mean /= len(reference_files) aspect_mean /= len(reference_files)
target_w = int(256*aspect_mean) if aspect_mean>1 else 256 target_w = int(256 * aspect_mean) if aspect_mean > 1 else 256
target_h = 256 if aspect_mean>=1 else int(256/aspect_mean) target_h = 256 if aspect_mean >= 1 else int(256 / aspect_mean)
refimgs = [] refimgs = []
for i, v in enumerate(refs): for i, v in enumerate(refs):
refimg = utils.addMergin( v, target_w=target_w, target_h=target_h ) refimg = utils.addMergin(v,
refimg = np.array(refimg).astype('float32').transpose(2, 0, 1) / 255.0 target_w=target_w,
target_h=target_h)
refimg = np.array(refimg).astype('float32').transpose(
2, 0, 1) / 255.0
refimgs.append(refimg) refimgs.append(refimg)
refimgs = paddle.to_tensor(np.array(refimgs).astype('float32')) refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
refimgs = paddle.unsqueeze(refimgs, 0) refimgs = paddle.unsqueeze(refimgs, 0)
# Load video # Load video
cap = cv2.VideoCapture( self.input ) cap = cv2.VideoCapture(self.input)
nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH) v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
minwh = min(v_w,v_h) minwh = min(v_w, v_h)
scale = 1 scale = 1
if minwh != self.mindim: if minwh != self.mindim:
scale = self.mindim / minwh scale = self.mindim / minwh
t_w = round(v_w*scale/16.)*16 t_w = round(v_w * scale / 16.) * 16
t_h = round(v_h*scale/16.)*16 t_h = round(v_h * scale / 16.) * 16
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
pbar = tqdm(total=nframes) pbar = tqdm(total=nframes)
block = 5 block = 5
...@@ -105,12 +124,12 @@ class DeepReasterPredictor: ...@@ -105,12 +124,12 @@ class DeepReasterPredictor:
with paddle.no_grad(): with paddle.no_grad():
it = 0 it = 0
while True: while True:
frame_pos = it*block frame_pos = it * block
if frame_pos >= nframes: if frame_pos >= nframes:
break break
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos) cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
if block >= nframes-frame_pos: if block >= nframes - frame_pos:
proc_g = nframes-frame_pos proc_g = nframes - frame_pos
else: else:
proc_g = block proc_g = block
...@@ -123,77 +142,96 @@ class DeepReasterPredictor: ...@@ -123,77 +142,96 @@ class DeepReasterPredictor:
nchannels = frame.shape[2] nchannels = frame.shape[2]
if nchannels == 1 or self.colorization: if nchannels == 1 or self.colorization:
frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
cv2.imwrite(outputdir_in+'%07d.png'%index, frame_l) cv2.imwrite(outputdir_in + '%07d.png' % index, frame_l)
frame_l = paddle.to_tensor(frame_l.astype('float32')) frame_l = paddle.to_tensor(frame_l.astype('float32'))
frame_l = paddle.reshape(frame_l, [frame_l.shape[0], frame_l.shape[1], 1]) frame_l = paddle.reshape(
frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
frame_l = paddle.transpose(frame_l, [2, 0, 1]) frame_l = paddle.transpose(frame_l, [2, 0, 1])
frame_l /= 255. frame_l /= 255.
frame_l = paddle.reshape(frame_l, [1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]]) frame_l = paddle.reshape(frame_l, [
1, frame_l.shape[0], 1, frame_l.shape[1],
frame_l.shape[2]
])
elif nchannels == 3: elif nchannels == 3:
cv2.imwrite(outputdir_in+'%07d.png'%index, frame) cv2.imwrite(outputdir_in + '%07d.png' % index, frame)
frame = frame[:,:,::-1] ## BGR -> RGB frame = frame[:, :, ::-1] ## BGR -> RGB
frame_l, frame_ab = utils.convertRGB2LABTensor( frame ) frame_l, frame_ab = utils.convertRGB2LABTensor(frame)
frame_l = frame_l.transpose([2, 0, 1]) frame_l = frame_l.transpose([2, 0, 1])
frame_ab = frame_ab.transpose([2, 0, 1]) frame_ab = frame_ab.transpose([2, 0, 1])
frame_l = frame_l.reshape([1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]]) frame_l = frame_l.reshape([
frame_ab = frame_ab.reshape([1, frame_ab.shape[0], 1, frame_ab.shape[1], frame_ab.shape[2]]) 1, frame_l.shape[0], 1, frame_l.shape[1],
frame_l.shape[2]
])
frame_ab = frame_ab.reshape([
1, frame_ab.shape[0], 1, frame_ab.shape[1],
frame_ab.shape[2]
])
if input is not None: if input is not None:
paddle.concat( (input, frame_l), 2 ) paddle.concat((input, frame_l), 2)
input = frame_l if i==0 else paddle.concat( (input, frame_l), 2 ) input = frame_l if i == 0 else paddle.concat(
if nchannels==3 and not self.colorization: (input, frame_l), 2)
gtC = frame_ab if i==0 else paddle.concat( (gtC, frame_ab), 2 ) if nchannels == 3 and not self.colorization:
gtC = frame_ab if i == 0 else paddle.concat(
(gtC, frame_ab), 2)
input = paddle.to_tensor(input) input = paddle.to_tensor(input)
output_l = self.modelR(input) # [B, C, T, H, W]
output_l = self.modelR( input ) # [B, C, T, H, W]
# Save restoration output without colorization when using the option [--disable_colorization] # Save restoration output without colorization when using the option [--disable_colorization]
if not self.colorization: if not self.colorization:
for i in range( proc_g ): for i in range(proc_g):
index = frame_pos + i index = frame_pos + i
if nchannels==3: if nchannels == 3:
out_l = output_l.detach()[0,:,i] out_l = output_l.detach()[0, :, i]
out_ab = gtC[0,:,i] out_ab = gtC[0, :, i]
out = paddle.concat((out_l, out_ab),axis=0).detach().numpy().transpose((1, 2, 0)) out = paddle.concat(
out = Image.fromarray( np.uint8( utils.convertLAB2RGB( out )*255 ) ) (out_l, out_ab),
out.save( outputdir_out+'%07d.png'%(index) ) axis=0).detach().numpy().transpose((1, 2, 0))
out = Image.fromarray(
np.uint8(utils.convertLAB2RGB(out) * 255))
out.save(outputdir_out + '%07d.png' % (index))
else: else:
raise ValueError('channels of imag3 must be 3!') raise ValueError('channels of imag3 must be 3!')
# Perform colorization # Perform colorization
else: else:
if self.reference_dir is None: if self.reference_dir is None:
output_ab = self.modelC( output_l ) output_ab = self.modelC(output_l)
else: else:
output_ab = self.modelC( output_l, refimgs ) output_ab = self.modelC(output_l, refimgs)
output_l = output_l.detach() output_l = output_l.detach()
output_ab = output_ab.detach() output_ab = output_ab.detach()
for i in range(proc_g):
for i in range( proc_g ):
index = frame_pos + i index = frame_pos + i
out_l = output_l[0,:,i,:,:] out_l = output_l[0, :, i, :, :]
out_c = output_ab[0,:,i,:,:] out_c = output_ab[0, :, i, :, :]
output = paddle.concat((out_l, out_c), axis=0).numpy().transpose((1, 2, 0)) output = paddle.concat(
output = Image.fromarray( np.uint8( utils.convertLAB2RGB( output )*255 ) ) (out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
output.save( outputdir_out+'%07d.png'%index ) output = Image.fromarray(
np.uint8(utils.convertLAB2RGB(output) * 255))
output.save(outputdir_out + '%07d.png' % index)
it = it + 1 it = it + 1
pbar.update(proc_g) pbar.update(proc_g)
# Save result videos # Save result videos
outfile = os.path.join(outputdir, self.input.split('/')[-1].split('.')[0]) outfile = os.path.join(outputdir,
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (fps, outputdir_in, fps, outfile ) self.input.split('/')[-1].split('.')[0])
subprocess.call( cmd, shell=True ) cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (fps, outputdir_out, fps, outfile ) fps, outputdir_in, fps, outfile)
subprocess.call( cmd, shell=True ) subprocess.call(cmd, shell=True)
cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % ( outfile, outfile, outfile ) cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (
subprocess.call( cmd, shell=True ) fps, outputdir_out, fps, outfile)
subprocess.call(cmd, shell=True)
cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % (
outfile, outfile, outfile)
subprocess.call(cmd, shell=True)
cap.release() cap.release()
pbar.close() pbar.close()
...@@ -203,7 +241,9 @@ class DeepReasterPredictor: ...@@ -203,7 +241,9 @@ class DeepReasterPredictor:
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
paddle.disable_static() paddle.disable_static()
predictor = DeepReasterPredictor(args.input, args.output, colorization=args.colorization, predictor = DeepReasterPredictor(args.input,
reference_dir=args.reference_dir, mindim=args.mindim) args.output,
colorization=args.colorization,
reference_dir=args.reference_dir,
mindim=args.mindim)
predictor.run() predictor.run()
\ No newline at end of file
...@@ -28,30 +28,29 @@ import paddle.fluid as fluid ...@@ -28,30 +28,29 @@ import paddle.fluid as fluid
import cv2 import cv2
from data import EDVRDataset from data import EDVRDataset
from paddle.incubate.hapi.download import get_path_from_url from paddle.utils.download import get_path_from_url
EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar' EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument('--input',
'--input',
type=str, type=str,
default=None, default=None,
help='input video path') help='input video path')
parser.add_argument( parser.add_argument('--output',
'--output',
type=str, type=str,
default='output', default='output',
help='output path') help='output path')
parser.add_argument( parser.add_argument('--weight_path',
'--weight_path',
type=str, type=str,
default=None, default=None,
help='weight path') help='weight path')
args = parser.parse_args() args = parser.parse_args()
return args return args
def get_img(pred): def get_img(pred):
print('pred shape', pred.shape) print('pred shape', pred.shape)
pred = pred.squeeze() pred = pred.squeeze()
...@@ -63,6 +62,7 @@ def get_img(pred): ...@@ -63,6 +62,7 @@ def get_img(pred):
pred = pred[:, :, ::-1] # rgb -> bgr pred = pred[:, :, ::-1] # rgb -> bgr
return pred return pred
def save_img(img, framename): def save_img(img, framename):
dirname = os.path.dirname(framename) dirname = os.path.dirname(framename)
if not os.path.exists(dirname): if not os.path.exists(dirname):
...@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
if ss is not None and t is not None and r is not None: if ss is not None and t is not None and r is not None:
cmd = ffmpeg + [ cmd = ffmpeg + [
' -ss ', ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
ss, ' 0.1 ', ' -start_number ', ' 0 ', outformat
' -t ',
t,
' -i ',
vid_path,
' -r ',
r,
' -qscale:v ',
' 0.1 ',
' -start_number ',
' 0 ',
outformat
] ]
else: else:
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
...@@ -134,7 +123,8 @@ class EDVRPredictor: ...@@ -134,7 +123,8 @@ class EDVRPredictor:
self.input = input self.input = input
self.output = os.path.join(output, 'EDVR') self.output = os.path.join(output, 'EDVR')
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace() place = fluid.CUDAPlace(
0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
self.exe = fluid.Executor(place) self.exe = fluid.Executor(place)
if weight_path is None: if weight_path is None:
...@@ -177,15 +167,18 @@ class EDVRPredictor: ...@@ -177,15 +167,18 @@ class EDVRPredictor:
for infer_iter, data in enumerate(dataset): for infer_iter, data in enumerate(dataset):
data_feed_in = [data[0]] data_feed_in = [data[0]]
infer_outs = self.exe.run(self.infer_prog, infer_outs = self.exe.run(
self.infer_prog,
fetch_list=self.fetch_list, fetch_list=self.fetch_list,
feed={self.feed_list[0]:np.array(data_feed_in)}) feed={self.feed_list[0]: np.array(data_feed_in)})
infer_result_list = [item for item in infer_outs] infer_result_list = [item for item in infer_outs]
frame_path = data[1] frame_path = data[1]
img_i = get_img(infer_result_list[0]) img_i = get_img(infer_result_list[0])
save_img(img_i, os.path.join(pred_frame_path, os.path.basename(frame_path))) save_img(
img_i,
os.path.join(pred_frame_path, os.path.basename(frame_path)))
prev_time = cur_time prev_time = cur_time
cur_time = time.time() cur_time = time.time()
...@@ -194,13 +187,15 @@ class EDVRPredictor: ...@@ -194,13 +187,15 @@ class EDVRPredictor:
print('Processed {} samples'.format(infer_iter + 1)) print('Processed {} samples'.format(infer_iter + 1))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(self.output, '{}_edvr_out.mp4'.format(base_name)) vid_out_path = os.path.join(self.output,
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps))) '{}_edvr_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args()
predictor = EDVRPredictor(args.input, args.output, args.weight_path) predictor = EDVRPredictor(args.input, args.output, args.weight_path)
predictor.run() predictor.run()
...@@ -60,7 +60,8 @@ dataset: ...@@ -60,7 +60,8 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
lr_scheduler:
name: linear name: linear
learning_rate: 0.0002 learning_rate: 0.0002
start_epoch: 100 start_epoch: 100
...@@ -72,4 +73,3 @@ log_config: ...@@ -72,4 +73,3 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
...@@ -59,7 +59,8 @@ dataset: ...@@ -59,7 +59,8 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
lr_scheduler:
name: linear name: linear
learning_rate: 0.0002 learning_rate: 0.0002
start_epoch: 100 start_epoch: 100
...@@ -71,4 +72,3 @@ log_config: ...@@ -71,4 +72,3 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
...@@ -25,7 +25,7 @@ dataset: ...@@ -25,7 +25,7 @@ dataset:
train: train:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes dataroot: data/cityscapes
num_workers: 0 num_workers: 4
phase: train phase: train
max_dataset_size: inf max_dataset_size: inf
direction: BtoA direction: BtoA
...@@ -57,7 +57,8 @@ dataset: ...@@ -57,7 +57,8 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
lr_scheduler:
name: linear name: linear
learning_rate: 0.0002 learning_rate: 0.0002
start_epoch: 100 start_epoch: 100
...@@ -69,4 +70,3 @@ log_config: ...@@ -69,4 +70,3 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
...@@ -56,7 +56,8 @@ dataset: ...@@ -56,7 +56,8 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
lr_scheduler:
name: linear name: linear
learning_rate: 0.0004 learning_rate: 0.0004
start_epoch: 100 start_epoch: 100
...@@ -68,4 +69,3 @@ log_config: ...@@ -68,4 +69,3 @@ log_config:
snapshot_config: snapshot_config:
interval: 5 interval: 5
...@@ -56,7 +56,8 @@ dataset: ...@@ -56,7 +56,8 @@ dataset:
optimizer: optimizer:
name: Adam name: Adam
beta1: 0.5 beta1: 0.5
lr_scheduler:
lr_scheduler:
name: linear name: linear
learning_rate: 0.0002 learning_rate: 0.0002
start_epoch: 100 start_epoch: 100
......
...@@ -6,7 +6,7 @@ from paddle.io import Dataset ...@@ -6,7 +6,7 @@ from paddle.io import Dataset
from PIL import Image from PIL import Image
import cv2 import cv2
import paddle.incubate.hapi.vision.transforms as transforms import paddle.vision.transforms as transforms
from .transforms import transforms as T from .transforms import transforms as T
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod ...@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets. """This class is an abstract base class (ABC) for datasets.
""" """
def __init__(self, cfg): def __init__(self, cfg):
"""Initialize the class; save the options in the class """Initialize the class; save the options in the class
...@@ -60,8 +59,11 @@ def get_params(cfg, size): ...@@ -60,8 +59,11 @@ def get_params(cfg, size):
return {'crop_pos': (x, y), 'flip': flip} return {'crop_pos': (x, y), 'flip': flip}
def get_transform(cfg,
def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, convert=True): params=None,
grayscale=False,
method=cv2.INTER_CUBIC,
convert=True):
transform_list = [] transform_list = []
if grayscale: if grayscale:
print('grayscale not support for now!!!') print('grayscale not support for now!!!')
...@@ -92,5 +94,7 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con ...@@ -92,5 +94,7 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con
if convert: if convert:
transform_list += [transforms.Permute(to_rgb=True)] transform_list += [transforms.Permute(to_rgb=True)]
transform_list += [transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))] transform_list += [
transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
]
return transforms.Compose(transform_list) return transforms.Compose(transform_list)
...@@ -3,12 +3,11 @@ import paddle ...@@ -3,12 +3,11 @@ import paddle
import numbers import numbers
import numpy as np import numpy as np
from multiprocessing import Manager from multiprocessing import Manager
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
from paddle.incubate.hapi.distributed import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from ..utils.registry import Registry from ..utils.registry import Registry
DATASETS = Registry("DATASETS") DATASETS = Registry("DATASETS")
...@@ -60,14 +59,12 @@ class DictDataLoader(): ...@@ -60,14 +59,12 @@ class DictDataLoader():
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \ place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0) if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)
sampler = DistributedBatchSampler( sampler = DistributedBatchSampler(self.dataset,
self.dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=True if is_train else False, shuffle=True if is_train else False,
drop_last=True if is_train else False) drop_last=True if is_train else False)
self.dataloader = paddle.io.DataLoader( self.dataloader = paddle.io.DataLoader(self.dataset,
self.dataset,
batch_sampler=sampler, batch_sampler=sampler,
places=place, places=place,
num_workers=num_workers) num_workers=num_workers)
...@@ -83,7 +80,9 @@ class DictDataLoader(): ...@@ -83,7 +80,9 @@ class DictDataLoader():
j = 0 j = 0
for k in self.dataset.keys: for k in self.dataset.keys:
if k in self.dataset.tensor_keys_set: if k in self.dataset.tensor_keys_set:
return_dict[k] = data[j] if isinstance(data, (list, tuple)) else data return_dict[k] = data[j] if isinstance(data,
(list,
tuple)) else data
j += 1 j += 1
else: else:
return_dict[k] = self.get_items_by_indexs(k, data[-1]) return_dict[k] = self.get_items_by_indexs(k, data[-1])
...@@ -104,7 +103,6 @@ class DictDataLoader(): ...@@ -104,7 +103,6 @@ class DictDataLoader():
return current_items return current_items
def build_dataloader(cfg, is_train=True): def build_dataloader(cfg, is_train=True):
dataset = DATASETS.get(cfg.name)(cfg) dataset = DATASETS.get(cfg.name)(cfg)
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import logging import logging
import paddle import paddle
from paddle import ParallelEnv, DataParallel from paddle.distributed import ParallelEnv
from ..datasets.builder import build_dataloader from ..datasets.builder import build_dataloader
from ..models.builder import build_model from ..models.builder import build_model
...@@ -19,7 +19,8 @@ class Trainer: ...@@ -19,7 +19,8 @@ class Trainer:
self.train_dataloader = build_dataloader(cfg.dataset.train) self.train_dataloader = build_dataloader(cfg.dataset.train)
if 'lr_scheduler' in cfg.optimizer: if 'lr_scheduler' in cfg.optimizer:
cfg.optimizer.lr_scheduler.step_per_epoch = len(self.train_dataloader) cfg.optimizer.lr_scheduler.step_per_epoch = len(
self.train_dataloader)
# build model # build model
self.model = build_model(cfg) self.model = build_model(cfg)
...@@ -50,7 +51,8 @@ class Trainer: ...@@ -50,7 +51,8 @@ class Trainer:
for name in self.model.model_names: for name in self.model.model_names:
if isinstance(name, str): if isinstance(name, str):
net = getattr(self.model, 'net' + name) net = getattr(self.model, 'net' + name)
setattr(self.model, 'net' + name, DataParallel(net, strategy)) setattr(self.model, 'net' + name,
paddle.DataParallel(net, strategy))
def train(self): def train(self):
...@@ -74,14 +76,17 @@ class Trainer: ...@@ -74,14 +76,17 @@ class Trainer:
self.visual('visual_train') self.visual('visual_train')
step_start_time = time.time() step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() - start_time)) self.logger.info('train one epoch time: {}'.format(time.time() -
start_time))
self.model.lr_scheduler.step()
if epoch % self.weight_interval == 0: if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1) self.save(epoch, 'weight', keep=-1)
self.save(epoch) self.save(epoch)
def test(self): def test(self):
if not hasattr(self, 'test_dataloader'): if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test, is_train=False) self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False)
# data[0]: img, data[1]: img path index # data[0]: img, data[1]: img path index
# test batch size must be 1 # test batch size must be 1
...@@ -105,7 +110,8 @@ class Trainer: ...@@ -105,7 +110,8 @@ class Trainer:
self.visual('visual_test', visual_results=visual_results) self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' % (i, len(self.test_dataloader))) self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader)))
def print_log(self): def print_log(self):
losses = self.model.get_current_losses() losses = self.model.get_current_losses()
...@@ -143,7 +149,8 @@ class Trainer: ...@@ -143,7 +149,8 @@ class Trainer:
makedirs(os.path.join(self.output_dir, results_dir)) makedirs(os.path.join(self.output_dir, results_dir))
for label, image in visual_results.items(): for label, image in visual_results.items():
image_numpy = tensor2img(image) image_numpy = tensor2img(image)
img_path = os.path.join(self.output_dir, results_dir, msg + '%s.png' % (label)) img_path = os.path.join(self.output_dir, results_dir,
msg + '%s.png' % (label))
save_image(image_numpy, img_path) save_image(image_numpy, img_path)
def save(self, epoch, name='checkpoint', keep=1): def save(self, epoch, name='checkpoint', keep=1):
...@@ -175,8 +182,8 @@ class Trainer: ...@@ -175,8 +182,8 @@ class Trainer:
if keep > 0: if keep > 0:
try: try:
checkpoint_name_to_be_removed = os.path.join(self.output_dir, checkpoint_name_to_be_removed = os.path.join(
'epoch_%s_%s.pkl' % (epoch - keep, name)) self.output_dir, 'epoch_%s_%s.pkl' % (epoch - keep, name))
if os.path.exists(checkpoint_name_to_be_removed): if os.path.exists(checkpoint_name_to_be_removed):
os.remove(checkpoint_name_to_be_removed) os.remove(checkpoint_name_to_be_removed)
...@@ -205,4 +212,3 @@ class Trainer: ...@@ -205,4 +212,3 @@ class Trainer:
if isinstance(name, str): if isinstance(name, str):
net = getattr(self.model, 'net' + name) net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name]) net.set_dict(state_dicts['net' + name])
\ No newline at end of file
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from collections import OrderedDict from collections import OrderedDict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ..solver.lr_scheduler import build_lr_scheduler
class BaseModel(ABC): class BaseModel(ABC):
...@@ -16,7 +17,6 @@ class BaseModel(ABC): ...@@ -16,7 +17,6 @@ class BaseModel(ABC):
-- <optimize_parameters>: calculate losses, gradients, and update network weights. -- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options. -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
""" """
def __init__(self, opt): def __init__(self, opt):
"""Initialize the BaseModel class. """Initialize the BaseModel class.
...@@ -33,7 +33,9 @@ class BaseModel(ABC): ...@@ -33,7 +33,9 @@ class BaseModel(ABC):
""" """
self.opt = opt self.opt = opt
self.isTrain = opt.isTrain self.isTrain = opt.isTrain
self.save_dir = os.path.join(opt.output_dir, opt.model.name) # save all the checkpoints to save_dir self.save_dir = os.path.join(
opt.output_dir,
opt.model.name) # save all the checkpoints to save_dir
self.loss_names = [] self.loss_names = []
self.model_names = [] self.model_names = []
...@@ -75,6 +77,8 @@ class BaseModel(ABC): ...@@ -75,6 +77,8 @@ class BaseModel(ABC):
"""Calculate losses, gradients, and update network weights; called in every training iteration""" """Calculate losses, gradients, and update network weights; called in every training iteration"""
pass pass
def build_lr_scheduler(self):
self.lr_scheduler = build_lr_scheduler(self.opt.lr_scheduler)
def eval(self): def eval(self):
"""Make models eval mode during test time""" """Make models eval mode during test time"""
...@@ -114,10 +118,11 @@ class BaseModel(ABC): ...@@ -114,10 +118,11 @@ class BaseModel(ABC):
errors_ret = OrderedDict() errors_ret = OrderedDict()
for name in self.loss_names: for name in self.loss_names:
if isinstance(name, str): if isinstance(name, str):
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number errors_ret[name] = float(
getattr(self, 'loss_' + name)
) # float(...) works for both scalar tensor and float number
return errors_ret return errors_ret
def set_requires_grad(self, nets, requires_grad=False): def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters: Parameters:
......
import paddle import paddle
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
...@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel): ...@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel):
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
""" """
def __init__(self, opt): def __init__(self, opt):
"""Initialize the CycleGAN class. """Initialize the CycleGAN class.
...@@ -32,7 +31,9 @@ class CycleGANModel(BaseModel): ...@@ -32,7 +31,9 @@ class CycleGANModel(BaseModel):
""" """
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'] self.loss_names = [
'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B'] visual_names_B = ['real_B', 'fake_A', 'rec_B']
...@@ -62,7 +63,8 @@ class CycleGANModel(BaseModel): ...@@ -62,7 +63,8 @@ class CycleGANModel(BaseModel):
if self.isTrain: if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert(opt.dataset.train.input_nc == opt.dataset.train.output_nc) assert (
opt.dataset.train.input_nc == opt.dataset.train.output_nc)
# create image buffer to store previously generated images # create image buffer to store previously generated images
self.fake_A_pool = ImagePool(opt.dataset.train.pool_size) self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
# create image buffer to store previously generated images # create image buffer to store previously generated images
...@@ -72,8 +74,17 @@ class CycleGANModel(BaseModel): ...@@ -72,8 +74,17 @@ class CycleGANModel(BaseModel):
self.criterionCycle = paddle.nn.L1Loss() self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss() self.criterionIdt = paddle.nn.L1Loss()
self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG_A.parameters() + self.netG_B.parameters()) self.build_lr_scheduler()
self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters() + self.netD_B.parameters()) self.optimizer_G = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netG_A.parameters() +
self.netG_B.parameters())
self.optimizer_D = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netD_A.parameters() +
self.netD_B.parameters())
self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D) self.optimizers.append(self.optimizer_D)
...@@ -107,7 +118,6 @@ class CycleGANModel(BaseModel): ...@@ -107,7 +118,6 @@ class CycleGANModel(BaseModel):
elif 'B_paths' in input: elif 'B_paths' in input:
self.image_paths = input['B_paths'] self.image_paths = input['B_paths']
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
if hasattr(self, 'real_A'): if hasattr(self, 'real_A'):
...@@ -118,7 +128,6 @@ class CycleGANModel(BaseModel): ...@@ -118,7 +128,6 @@ class CycleGANModel(BaseModel):
self.fake_A = self.netG_B(self.real_B) # G_B(B) self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
def backward_D_basic(self, netD, real, fake): def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator """Calculate GAN loss for the discriminator
...@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel): ...@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel):
if lambda_idt > 0: if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B|| # G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B) self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A|| # G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A) self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_A) * lambda_A * lambda_idt
else: else:
self.loss_idt_A = 0 self.loss_idt_A = 0
self.loss_idt_B = 0 self.loss_idt_B = 0
...@@ -179,9 +190,11 @@ class CycleGANModel(BaseModel): ...@@ -179,9 +190,11 @@ class CycleGANModel(BaseModel):
# GAN loss D_B(G_B(B)) # GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
# Forward cycle loss || G_B(G_A(A)) - A|| # Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B|| # Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
# combined loss and calculate gradients # combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
...@@ -218,4 +231,3 @@ class CycleGANModel(BaseModel): ...@@ -218,4 +231,3 @@ class CycleGANModel(BaseModel):
self.backward_D_B() self.backward_D_B()
# update D_A and D_B's weights # update D_A and D_B's weights
self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B) self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)
import paddle import paddle
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
...@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel): ...@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel):
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
""" """
def __init__(self, opt): def __init__(self, opt):
"""Initialize the pix2pix class. """Initialize the pix2pix class.
...@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel): ...@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel):
if self.isTrain: if self.isTrain:
self.netD = build_discriminator(opt.model.discriminator) self.netD = build_discriminator(opt.model.discriminator)
if self.isTrain: if self.isTrain:
# define loss functions # define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode) self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionL1 = paddle.nn.L1Loss() self.criterionL1 = paddle.nn.L1Loss()
# build optimizers # build optimizers
self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters()) self.build_lr_scheduler()
self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters()) self.optimizer_G = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netG.parameters())
self.optimizer_D = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netD.parameters())
self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D) self.optimizers.append(self.optimizer_D)
...@@ -76,7 +81,6 @@ class Pix2PixModel(BaseModel): ...@@ -76,7 +81,6 @@ class Pix2PixModel(BaseModel):
self.real_B = paddle.to_tensor(input['B' if AtoB else 'A']) self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
self.image_paths = input['A_paths' if AtoB else 'B_paths'] self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG(self.real_A) # G(A) self.fake_B = self.netG(self.real_A) # G(A)
...@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel): ...@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel):
pred_fake = self.netD(fake_AB) pred_fake = self.netD(fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True) self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B # Second, G(A) = B
self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 self.loss_G_L1 = self.criterionL1(self.fake_B,
self.real_B) * self.opt.lambda_L1
# combine loss and calculate gradients # combine loss and calculate gradients
self.loss_G = self.loss_G_GAN + self.loss_G_L1 self.loss_G = self.loss_G_GAN + self.loss_G_L1
......
...@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg): ...@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg):
# TODO: add more learning rate scheduler # TODO: add more learning rate scheduler
if name == 'linear': if name == 'linear':
return LinearDecay(**cfg)
def lambda_rule(epoch):
lr_l = 1.0 - max(
0, epoch + 1 - cfg.start_epoch) / float(cfg.decay_epochs + 1)
return lr_l
scheduler = paddle.optimizer.lr_scheduler.LambdaLR(
cfg.learning_rate, lr_lambda=lambda_rule)
return scheduler
else: else:
raise NotImplementedError raise NotImplementedError
class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay): # paddle.optimizer.lr_scheduler
def __init__(self, learning_rate, step_per_epoch, start_epoch, decay_epochs): class LinearDecay(paddle.optimizer.lr_scheduler._LRScheduler):
def __init__(self, learning_rate, step_per_epoch, start_epoch,
decay_epochs):
super(LinearDecay, self).__init__() super(LinearDecay, self).__init__()
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.start_epoch = start_epoch self.start_epoch = start_epoch
...@@ -21,5 +31,6 @@ class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay ...@@ -21,5 +31,6 @@ class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay
def step(self): def step(self):
cur_epoch = int(self.step_num // self.step_per_epoch) cur_epoch = int(self.step_num // self.step_per_epoch)
decay_rate = 1.0 - max(0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1) decay_rate = 1.0 - max(
0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
return self.create_lr_var(decay_rate * self.learning_rate) return self.create_lr_var(decay_rate * self.learning_rate)
...@@ -4,13 +4,11 @@ import paddle ...@@ -4,13 +4,11 @@ import paddle
from .lr_scheduler import build_lr_scheduler from .lr_scheduler import build_lr_scheduler
def build_optimizer(cfg, parameter_list=None): def build_optimizer(cfg, lr_scheduler, parameter_list=None):
cfg_copy = copy.deepcopy(cfg) cfg_copy = copy.deepcopy(cfg)
lr_scheduler_cfg = cfg_copy.pop('lr_scheduler', None)
lr_scheduler = build_lr_scheduler(lr_scheduler_cfg)
opt_name = cfg_copy.pop('name') opt_name = cfg_copy.pop('name')
return getattr(paddle.optimizer, opt_name)(lr_scheduler, parameters=parameter_list, **cfg_copy) return getattr(paddle.optimizer, opt_name)(lr_scheduler,
parameters=parameter_list,
**cfg_copy)
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import os import os
import sys import sys
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
def setup_logger(output=None, name="ppgan"): def setup_logger(output=None, name="ppgan"):
...@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"): ...@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"):
logger.propagate = False logger.propagate = False
plain_formatter = logging.Formatter( plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
) datefmt="%m/%d %H:%M:%S")
# stdout logging: master only # stdout logging: master only
local_rank = ParallelEnv().local_rank local_rank = ParallelEnv().local_rank
if local_rank == 0: if local_rank == 0:
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import time import time
import paddle import paddle
from paddle import ParallelEnv from paddle.distributed import ParallelEnv
from .logger import setup_logger from .logger import setup_logger
...@@ -12,7 +12,8 @@ def setup(args, cfg): ...@@ -12,7 +12,8 @@ def setup(args, cfg):
cfg.isTrain = False cfg.isTrain = False
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
cfg.output_dir = os.path.join(cfg.output_dir, str(cfg.model.name) + cfg.timestamp) cfg.output_dir = os.path.join(cfg.output_dir,
str(cfg.model.name) + cfg.timestamp)
logger = setup_logger(cfg.output_dir) logger = setup_logger(cfg.output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册