diff --git a/applications/DAIN/predict.py b/applications/DAIN/predict.py
index 15e23fa010ed65041d4b54c4467978b06bb8e187..6c6e5234fa1584e3704f8873a3cfa658ad689e71 100644
--- a/applications/DAIN/predict.py
+++ b/applications/DAIN/predict.py
@@ -11,7 +11,7 @@ from imageio import imread, imsave
 import cv2
 
 import paddle.fluid as fluid
-from paddle.incubate.hapi.download import get_path_from_url
+from paddle.utils.download import get_path_from_url
 
 import networks
 from util import *
@@ -19,6 +19,7 @@ from my_args import parser
 
 DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'
 
+
 def infer_engine(model_dir,
                  run_mode='fluid',
                  batch_size=1,
@@ -91,7 +92,6 @@ class VideoFrameInterp(object):
         self.exe, self.program, self.fetch_targets = executor(model_path,
                                                               use_gpu=use_gpu)
 
-
     def run(self):
         frame_path_input = os.path.join(self.output_path, 'frames-input')
         frame_path_interpolated = os.path.join(self.output_path,
@@ -272,7 +272,7 @@ class VideoFrameInterp(object):
             os.remove(video_pattern_output)
         frames_to_video_ffmpeg(frame_pattern_combined, video_pattern_output,
                                r2)
-
+
         return frame_pattern_combined, video_pattern_output
diff --git a/applications/DeOldify/predict.py b/applications/DeOldify/predict.py
index 1f73fd45b499599bf5935e8d6aa85bd8a1fdddff..df77d48de177f8cf2af06d2e21470d8a6e40ceb0 100644
--- a/applications/DeOldify/predict.py
+++ b/applications/DeOldify/predict.py
@@ -15,15 +15,19 @@ from PIL import Image
 from tqdm import tqdm
 from paddle import fluid
 from model import build_model
-from paddle.incubate.hapi.download import get_path_from_url
+from paddle.utils.download import get_path_from_url
 
 parser = argparse.ArgumentParser(description='DeOldify')
-parser.add_argument('--input', type=str, default='none', help='Input video') 
-parser.add_argument('--output', type=str, default='output', help='output dir') 
-parser.add_argument('--weight_path', type=str, default='none', help='Path to the reference image directory')
+parser.add_argument('--input', type=str, default='none', help='Input video')
+parser.add_argument('--output', type=str, default='output', help='output dir')
+parser.add_argument('--weight_path',
+                    type=str,
+                    default='none',
+                    help='Path to the reference image directory')
 
 DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
 
+
 def frames_to_video_ffmpeg(framepath, videopath, r):
     ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
     cmd = ffmpeg + [
@@ -56,9 +60,9 @@ class DeOldifyPredictor():
     def norm(self, img, render_factor=32, render_base=16):
         target_size = render_factor * render_base
         img = img.resize((target_size, target_size), resample=Image.BILINEAR)
-
+
         img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0
-
+
         img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
         img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
 
@@ -69,13 +73,13 @@ class DeOldifyPredictor():
     def denorm(self, img):
         img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
         img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
-
+
         img *= img_std
         img += img_mean
         img = img.transpose((1, 2, 0))
 
         return (img * 255).clip(0, 255).astype('uint8')
-
+
     def post_process(self, raw_color, orig):
         color_np = np.asarray(raw_color)
         orig_np = np.asarray(orig)
@@ -86,11 +90,11 @@ class DeOldifyPredictor():
         final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
         final = Image.fromarray(final)
         return final
-
+
     def run_single(self, img_path):
         ori_img = Image.open(img_path).convert('LA').convert('RGB')
         img = self.norm(ori_img)
-        x = paddle.to_tensor(img[np.newaxis,...])
+        x = paddle.to_tensor(img[np.newaxis, ...])
         out = self.model(x)
 
         pred_img = self.denorm(out.numpy()[0])
@@ -118,20 +122,20 @@ class DeOldifyPredictor():
 
         frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
 
-
         for frame in tqdm(frames):
             pred_img = self.run_single(frame)
             frame_name = os.path.basename(frame)
             pred_img.save(os.path.join(pred_frame_path, frame_name))
-
+
         frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
-
-        vid_out_path = os.path.join(output_path, '{}_deoldify_out.mp4'.format(base_name))
-        frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps)))
-        return frame_pattern_combined, vid_out_path
+        vid_out_path = os.path.join(output_path,
+                                    '{}_deoldify_out.mp4'.format(base_name))
+        frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
+                               str(int(fps)))
+        return frame_pattern_combined, vid_out_path
 
 
 def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
 
     if ss is not None and t is not None and r is not None:
         cmd = ffmpeg + [
-            ' -ss ',
-            ss,
-            ' -t ',
-            t,
-            ' -i ',
-            vid_path,
-            ' -r ',
-            r,
-
-            ' -qscale:v ',
-            ' 0.1 ',
-            ' -start_number ',
-            ' 0 ',
-
-            outformat
+            ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
+            ' 0.1 ', ' -start_number ', ' 0 ', outformat
        ]
     else:
         cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
     return out_full_path
 
 
-if __name__=='__main__':
+if __name__ == '__main__':
     paddle.enable_imperative()
     args = parser.parse_args()
 
-    predictor = DeOldifyPredictor(args.input, args.output, weight_path=args.weight_path)
+    predictor = DeOldifyPredictor(args.input,
+                                  args.output,
+                                  weight_path=args.weight_path)
     frames_path, temp_video_path = predictor.run()
 
-    print('output video path:', temp_video_path)
\ No newline at end of file
+    print('output video path:', temp_video_path)
diff --git a/applications/DeepRemaster/predict.py b/applications/DeepRemaster/predict.py
index 7d2fbcfdb9d8793f8d6b0e23ea0e5aca4bc80243..3ad54b31eff7bb14f630e0c1e66970f96a12e173 100644
--- a/applications/DeepRemaster/predict.py
+++ b/applications/DeepRemaster/predict.py
@@ -15,195 +15,235 @@ import argparse
 import subprocess
 import utils
 from remasternet import NetworkR, NetworkC
-from paddle.incubate.hapi.download import get_path_from_url
+from paddle.utils.download import get_path_from_url
 
 DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
 
 parser = argparse.ArgumentParser(description='Remastering')
-parser.add_argument('--input', type=str, default=None, help='Input video') 
-parser.add_argument('--output', type=str, default='output', help='output dir') 
-parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory')
-parser.add_argument('--colorization', action='store_true', default=False, help='Remaster without colorization')
-parser.add_argument('--mindim', type=int, default='360', help='Length of minimum image edges')
+parser.add_argument('--input', type=str, default=None, help='Input video')
+parser.add_argument('--output', type=str, default='output', help='output dir')
+parser.add_argument('--reference_dir',
+                    type=str,
+                    default=None,
+                    help='Path to the reference image directory')
+parser.add_argument('--colorization',
+                    action='store_true',
+                    default=False,
+                    help='Remaster without colorization')
+parser.add_argument('--mindim',
+                    type=int,
+                    default='360',
+                    help='Length of minimum image edges')
 
 
 class DeepReasterPredictor:
-    def __init__(self, input, output, weight_path=None, colorization=False, reference_dir=None, mindim=360):
-        self.input = input
-        self.output = os.path.join(output, 'DeepRemaster')
-        self.colorization = colorization
-        self.reference_dir = reference_dir
-        self.mindim = mindim
-
-        if weight_path is None:
-            weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path)
-
-        state_dict, _ = paddle.load(weight_path)
-
-        self.modelR = NetworkR()
-        self.modelR.load_dict(state_dict['modelR'])
-        self.modelR.eval()
-        if colorization:
-            self.modelC = NetworkC()
-            self.modelC.load_dict(state_dict['modelC'])
-            self.modelC.eval()
-
-
-    def run(self):
-        outputdir = self.output
-        outputdir_in = os.path.join(outputdir, 'input/')
-        os.makedirs( outputdir_in, exist_ok=True )
-        outputdir_out = os.path.join(outputdir, 'output/')
-        os.makedirs( outputdir_out, exist_ok=True )
-
-        # Prepare reference images
-        if self.colorization:
-            if self.reference_dir is not None:
-                import glob
-                ext_list = ['png','jpg','bmp']
-                reference_files = []
-                for ext in ext_list:
-                    reference_files += glob.glob( self.reference_dir+'/*.'+ext, recursive=True )
-                aspect_mean = 0
-                minedge_dim = 256
-                refs = []
-                for v in reference_files:
-                    refimg = Image.open( v ).convert('RGB')
-                    w, h = refimg.size
-                    aspect_mean += w/h
-                    refs.append( refimg )
-                aspect_mean /= len(reference_files)
-                target_w = int(256*aspect_mean) if aspect_mean>1 else 256
-                target_h = 256 if aspect_mean>=1 else int(256/aspect_mean)
-
-                refimgs = []
-                for i, v in enumerate(refs):
-                    refimg = utils.addMergin( v, target_w=target_w, target_h=target_h )
-                    refimg = np.array(refimg).astype('float32').transpose(2, 0, 1) / 255.0
-                    refimgs.append(refimg)
-                refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
-
-                refimgs = paddle.unsqueeze(refimgs, 0)
-
-        # Load video
-        cap = cv2.VideoCapture( self.input )
-        nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
-        v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
-        minwh = min(v_w,v_h)
-        scale = 1
-        if minwh != self.mindim:
-            scale = self.mindim / minwh
-
-        t_w = round(v_w*scale/16.)*16
-        t_h = round(v_h*scale/16.)*16
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        pbar = tqdm(total=nframes)
-        block = 5
-
-        # Process
-        with paddle.no_grad():
-            it = 0
-            while True:
-                frame_pos = it*block
-                if frame_pos >= nframes:
-                    break
-                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
-                if block >= nframes-frame_pos:
-                    proc_g = nframes-frame_pos
-                else:
-                    proc_g = block
-
-                input = None
-                gtC = None
-                for i in range(proc_g):
-                    index = frame_pos + i
-                    _, frame = cap.read()
-                    frame = cv2.resize(frame, (t_w, t_h))
-                    nchannels = frame.shape[2]
-                    if nchannels == 1 or self.colorization:
-                        frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
-                        cv2.imwrite(outputdir_in+'%07d.png'%index, frame_l)
-                        frame_l = paddle.to_tensor(frame_l.astype('float32'))
-                        frame_l = paddle.reshape(frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
-                        frame_l = paddle.transpose(frame_l, [2, 0, 1])
-                        frame_l /= 255.
-
-                        frame_l = paddle.reshape(frame_l, [1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
-                    elif nchannels == 3:
-                        cv2.imwrite(outputdir_in+'%07d.png'%index, frame)
-                        frame = frame[:,:,::-1] ## BGR -> RGB
-                        frame_l, frame_ab = utils.convertRGB2LABTensor( frame )
-                        frame_l = frame_l.transpose([2, 0, 1])
-                        frame_ab = frame_ab.transpose([2, 0, 1])
-                        frame_l = frame_l.reshape([1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
-                        frame_ab = frame_ab.reshape([1, frame_ab.shape[0], 1, frame_ab.shape[1], frame_ab.shape[2]])
-
-                    if input is not None:
-                        paddle.concat( (input, frame_l), 2 )
-
-                    input = frame_l if i==0 else paddle.concat( (input, frame_l), 2 )
-                    if nchannels==3 and not self.colorization:
-                        gtC = frame_ab if i==0 else paddle.concat( (gtC, frame_ab), 2 )
-
-                input = paddle.to_tensor(input)
-
-
-                output_l = self.modelR( input ) # [B, C, T, H, W]
-
-                # Save restoration output without colorization when using the option [--disable_colorization]
-                if not self.colorization:
-                    for i in range( proc_g ):
-                        index = frame_pos + i
-                        if nchannels==3:
-                            out_l = output_l.detach()[0,:,i]
-                            out_ab = gtC[0,:,i]
-
-                            out = paddle.concat((out_l, out_ab),axis=0).detach().numpy().transpose((1, 2, 0))
-                            out = Image.fromarray( np.uint8( utils.convertLAB2RGB( out )*255 ) )
-                            out.save( outputdir_out+'%07d.png'%(index) )
-                        else:
-                            raise ValueError('channels of imag3 must be 3!')
-
-                # Perform colorization
-                else:
-                    if self.reference_dir is None:
-                        output_ab = self.modelC( output_l )
-                    else:
-                        output_ab = self.modelC( output_l, refimgs )
-                    output_l = output_l.detach()
-                    output_ab = output_ab.detach()
-
-
-                    for i in range( proc_g ):
-                        index = frame_pos + i
-                        out_l = output_l[0,:,i,:,:]
-                        out_c = output_ab[0,:,i,:,:]
-                        output = paddle.concat((out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
-                        output = Image.fromarray( np.uint8( utils.convertLAB2RGB( output )*255 ) )
-                        output.save( outputdir_out+'%07d.png'%index )
-
-                it = it + 1
-                pbar.update(proc_g)
-
-        # Save result videos
-        outfile = os.path.join(outputdir, self.input.split('/')[-1].split('.')[0])
-        cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (fps, outputdir_in, fps, outfile )
-        subprocess.call( cmd, shell=True )
-        cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (fps, outputdir_out, fps, outfile )
-        subprocess.call( cmd, shell=True )
-        cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % ( outfile, outfile, outfile )
-        subprocess.call( cmd, shell=True )
-
-        cap.release()
-        pbar.close()
-        return outputdir_out, '%s_out.mp4' % outfile
+    def __init__(self,
+                 input,
+                 output,
+                 weight_path=None,
+                 colorization=False,
+                 reference_dir=None,
+                 mindim=360):
+        self.input = input
+        self.output = os.path.join(output, 'DeepRemaster')
+        self.colorization = colorization
+        self.reference_dir = reference_dir
+        self.mindim = mindim
+
+        if weight_path is None:
+            weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path)
+
+        state_dict, _ = paddle.load(weight_path)
+
+        self.modelR = NetworkR()
+        self.modelR.load_dict(state_dict['modelR'])
+        self.modelR.eval()
+        if colorization:
+            self.modelC = NetworkC()
+            self.modelC.load_dict(state_dict['modelC'])
+            self.modelC.eval()
+
+    def run(self):
+        outputdir = self.output
+        outputdir_in = os.path.join(outputdir, 'input/')
+        os.makedirs(outputdir_in, exist_ok=True)
+        outputdir_out = os.path.join(outputdir, 'output/')
+        os.makedirs(outputdir_out, exist_ok=True)
+
+        # Prepare reference images
+        if self.colorization:
+            if self.reference_dir is not None:
+                import glob
+                ext_list = ['png', 'jpg', 'bmp']
+                reference_files = []
+                for ext in ext_list:
+                    reference_files += glob.glob(self.reference_dir + '/*.' +
+                                                 ext,
+                                                 recursive=True)
+                aspect_mean = 0
+                minedge_dim = 256
+                refs = []
+                for v in reference_files:
+                    refimg = Image.open(v).convert('RGB')
+                    w, h = refimg.size
+                    aspect_mean += w / h
+                    refs.append(refimg)
+                aspect_mean /= len(reference_files)
+                target_w = int(256 * aspect_mean) if aspect_mean > 1 else 256
+                target_h = 256 if aspect_mean >= 1 else int(256 / aspect_mean)
+
+                refimgs = []
+                for i, v in enumerate(refs):
+                    refimg = utils.addMergin(v,
+                                             target_w=target_w,
+                                             target_h=target_h)
+                    refimg = np.array(refimg).astype('float32').transpose(
+                        2, 0, 1) / 255.0
+                    refimgs.append(refimg)
+                refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
+
+                refimgs = paddle.unsqueeze(refimgs, 0)
+
+        # Load video
+        cap = cv2.VideoCapture(self.input)
+        nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
+        v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
+        minwh = min(v_w, v_h)
+        scale = 1
+        if minwh != self.mindim:
+            scale = self.mindim / minwh
+
+        t_w = round(v_w * scale / 16.) * 16
+        t_h = round(v_h * scale / 16.) * 16
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        pbar = tqdm(total=nframes)
+        block = 5
+
+        # Process
+        with paddle.no_grad():
+            it = 0
+            while True:
+                frame_pos = it * block
+                if frame_pos >= nframes:
+                    break
+                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
+                if block >= nframes - frame_pos:
+                    proc_g = nframes - frame_pos
+                else:
+                    proc_g = block
+
+                input = None
+                gtC = None
+                for i in range(proc_g):
+                    index = frame_pos + i
+                    _, frame = cap.read()
+                    frame = cv2.resize(frame, (t_w, t_h))
+                    nchannels = frame.shape[2]
+                    if nchannels == 1 or self.colorization:
+                        frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+                        cv2.imwrite(outputdir_in + '%07d.png' % index, frame_l)
+                        frame_l = paddle.to_tensor(frame_l.astype('float32'))
+                        frame_l = paddle.reshape(
+                            frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
+                        frame_l = paddle.transpose(frame_l, [2, 0, 1])
+                        frame_l /= 255.
+
+                        frame_l = paddle.reshape(frame_l, [
+                            1, frame_l.shape[0], 1, frame_l.shape[1],
+                            frame_l.shape[2]
+                        ])
+                    elif nchannels == 3:
+                        cv2.imwrite(outputdir_in + '%07d.png' % index, frame)
+                        frame = frame[:, :, ::-1]  ## BGR -> RGB
+                        frame_l, frame_ab = utils.convertRGB2LABTensor(frame)
+                        frame_l = frame_l.transpose([2, 0, 1])
+                        frame_ab = frame_ab.transpose([2, 0, 1])
+                        frame_l = frame_l.reshape([
+                            1, frame_l.shape[0], 1, frame_l.shape[1],
+                            frame_l.shape[2]
+                        ])
+                        frame_ab = frame_ab.reshape([
+                            1, frame_ab.shape[0], 1, frame_ab.shape[1],
+                            frame_ab.shape[2]
+                        ])
+
+                    if input is not None:
+                        paddle.concat((input, frame_l), 2)
+
+                    input = frame_l if i == 0 else paddle.concat(
+                        (input, frame_l), 2)
+                    if nchannels == 3 and not self.colorization:
+                        gtC = frame_ab if i == 0 else paddle.concat(
+                            (gtC, frame_ab), 2)
+
+                input = paddle.to_tensor(input)
+
+                output_l = self.modelR(input)  # [B, C, T, H, W]
+
+                # Save restoration output without colorization when using the option [--disable_colorization]
+                if not self.colorization:
+                    for i in range(proc_g):
+                        index = frame_pos + i
+                        if nchannels == 3:
+                            out_l = output_l.detach()[0, :, i]
+                            out_ab = gtC[0, :, i]
+
+                            out = paddle.concat(
+                                (out_l, out_ab),
+                                axis=0).detach().numpy().transpose((1, 2, 0))
+                            out = Image.fromarray(
+                                np.uint8(utils.convertLAB2RGB(out) * 255))
+                            out.save(outputdir_out + '%07d.png' % (index))
+                        else:
+                            raise ValueError('channels of imag3 must be 3!')
+
+                # Perform colorization
+                else:
+                    if self.reference_dir is None:
+                        output_ab = self.modelC(output_l)
+                    else:
+                        output_ab = self.modelC(output_l, refimgs)
+                    output_l = output_l.detach()
+                    output_ab = output_ab.detach()
+
+                    for i in range(proc_g):
+                        index = frame_pos + i
+                        out_l = output_l[0, :, i, :, :]
+                        out_c = output_ab[0, :, i, :, :]
+                        output = paddle.concat(
+                            (out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
+                        output = Image.fromarray(
+                            np.uint8(utils.convertLAB2RGB(output) * 255))
+                        output.save(outputdir_out + '%07d.png' % index)
+
+                it = it + 1
+                pbar.update(proc_g)
+
+        # Save result videos
+        outfile = os.path.join(outputdir,
+                               self.input.split('/')[-1].split('.')[0])
+        cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (
+            fps, outputdir_in, fps, outfile)
+        subprocess.call(cmd, shell=True)
+        cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (
+            fps, outputdir_out, fps, outfile)
+        subprocess.call(cmd, shell=True)
+        cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % (
+            outfile, outfile, outfile)
+        subprocess.call(cmd, shell=True)
+
+        cap.release()
+        pbar.close()
+        return outputdir_out, '%s_out.mp4' % outfile
 
 
 if __name__ == "__main__":
-    args = parser.parse_args() 
-    paddle.disable_static()
-    predictor = DeepReasterPredictor(args.input, args.output, colorization=args.colorization,
-                                     reference_dir=args.reference_dir, mindim=args.mindim)
-    predictor.run()
-    
\ No newline at end of file
+    args = parser.parse_args()
+    paddle.disable_static()
+    predictor = DeepReasterPredictor(args.input,
+                                     args.output,
+                                     colorization=args.colorization,
+                                     reference_dir=args.reference_dir,
+                                     mindim=args.mindim)
+    predictor.run()
diff --git a/applications/EDVR/predict.py b/applications/EDVR/predict.py
index a1cb1b98415f91f05d6333249f2468542fc2b6b2..4b888eaebecb3d0bf30ed634f9fb9eaac57b9ff7 100644
--- a/applications/EDVR/predict.py
+++ b/applications/EDVR/predict.py
@@ -28,30 +28,29 @@ import paddle.fluid as fluid
 import cv2
 
 from data import EDVRDataset
-from paddle.incubate.hapi.download import get_path_from_url
+from paddle.utils.download import get_path_from_url
 
 EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
 
+
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--input',
-        type=str,
-        default=None,
-        help='input video path')
-    parser.add_argument(
-        '--output',
-        type=str,
-        default='output',
-        help='output path')
-    parser.add_argument(
-        '--weight_path',
-        type=str,
-        default=None,
-        help='weight path')
+    parser.add_argument('--input',
+                        type=str,
+                        default=None,
+                        help='input video path')
+    parser.add_argument('--output',
+                        type=str,
+                        default='output',
+                        help='output path')
+    parser.add_argument('--weight_path',
+                        type=str,
+                        default=None,
+                        help='weight path')
     args = parser.parse_args()
     return args
 
+
 def get_img(pred):
     print('pred shape', pred.shape)
     pred = pred.squeeze()
@@ -59,10 +58,11 @@ def get_img(pred):
     pred = pred * 255
     pred = pred.round()
     pred = pred.astype('uint8')
-    pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc
-    pred = pred[:, :, ::-1] # rgb -> bgr
+    pred = np.transpose(pred, (1, 2, 0))  # chw -> hwc
+    pred = pred[:, :, ::-1]  # rgb -> bgr
     return pred
 
+
 def save_img(img, framename):
     dirname = os.path.dirname(framename)
     if not os.path.exists(dirname):
@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
 
     if ss is not None and t is not None and r is not None:
         cmd = ffmpeg + [
-            ' -ss ',
-            ss,
-            ' -t ',
-            t,
-            ' -i ',
-            vid_path,
-            ' -r ',
-            r,
-            ' -qscale:v ',
-            ' 0.1 ',
-            ' -start_number ',
-            ' 0 ',
-            outformat
+            ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
+            ' 0.1 ', ' -start_number ', ' 0 ', outformat
         ]
     else:
         cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
@@ -134,20 +123,21 @@ class EDVRPredictor:
         self.input = input
         self.output = os.path.join(output, 'EDVR')
 
-        place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
+        place = fluid.CUDAPlace(
+            0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
         self.exe = fluid.Executor(place)
 
         if weight_path is None:
             weight_path = get_path_from_url(EDVR_weight_url, cur_path)
-
+
         print(weight_path)
         model_filename = 'EDVR_model.pdmodel'
         params_filename = 'EDVR_params.pdparams'
-
-        out = fluid.io.load_inference_model(dirname=weight_path, 
-                                            model_filename=model_filename, 
-                                            params_filename=params_filename, 
+
+        out = fluid.io.load_inference_model(dirname=weight_path,
+                                            model_filename=model_filename,
+                                            params_filename=params_filename,
                                             executor=self.exe)
         self.infer_prog, self.feed_list, self.fetch_list = out
 
@@ -176,16 +166,19 @@ class EDVRPredictor:
         cur_time = time.time()
         for infer_iter, data in enumerate(dataset):
             data_feed_in = [data[0]]
-
-            infer_outs = self.exe.run(self.infer_prog,
-                                      fetch_list=self.fetch_list,
-                                      feed={self.feed_list[0]:np.array(data_feed_in)})
+
+            infer_outs = self.exe.run(
+                self.infer_prog,
+                fetch_list=self.fetch_list,
+                feed={self.feed_list[0]: np.array(data_feed_in)})
             infer_result_list = [item for item in infer_outs]
 
             frame_path = data[1]
-
+
             img_i = get_img(infer_result_list[0])
-            save_img(img_i, os.path.join(pred_frame_path, os.path.basename(frame_path)))
+            save_img(
+                img_i,
+                os.path.join(pred_frame_path, os.path.basename(frame_path)))
 
             prev_time = cur_time
             cur_time = time.time()
@@ -194,13 +187,15 @@ class EDVRPredictor:
             print('Processed {} samples'.format(infer_iter + 1))
 
         frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
-        vid_out_path = os.path.join(self.output, '{}_edvr_out.mp4'.format(base_name))
-        frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps)))
+        vid_out_path = os.path.join(self.output,
+                                    '{}_edvr_out.mp4'.format(base_name))
+        frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
+                               str(int(fps)))
 
         return frame_pattern_combined, vid_out_path
 
 
 if __name__ == "__main__":
+
     args = parse_args()
     predictor = EDVRPredictor(args.input, args.output, args.weight_path)
     predictor.run()
-
diff --git a/configs/cyclegan_cityscapes.yaml b/configs/cyclegan_cityscapes.yaml
index ea0d5ce7f65b3c06cce51fca283649c850b6a4bb..c4facd8bc817f5c3ffb450feae0c221880aa5dff 100644
--- a/configs/cyclegan_cityscapes.yaml
+++ b/configs/cyclegan_cityscapes.yaml
@@ -60,11 +60,12 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0002
-    start_epoch: 100
-    decay_epochs: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0002
+  start_epoch: 100
+  decay_epochs: 100
 
 log_config:
   interval: 100
@@ -72,4 +73,3 @@ log_config:
 
 snapshot_config:
   interval: 5
-
diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml
index 72c7cf6ccf3a9b058914498a25c8165400d3cb4e..1ea5c6d1687c35197a1470833b33cdefcf0ba5ee 100644
--- a/configs/cyclegan_horse2zebra.yaml
+++ b/configs/cyclegan_horse2zebra.yaml
@@ -59,11 +59,12 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0002
-    start_epoch: 100
-    decay_epochs: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0002
+  start_epoch: 100
+  decay_epochs: 100
 
 log_config:
   interval: 100
@@ -71,4 +72,3 @@ log_config:
 
 snapshot_config:
   interval: 5
-
diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml
index c641b58445dcfa97aa94937ab93b75a14d62110a..06577f7f1a50301431832b541b17c78802e35077 100644
--- a/configs/pix2pix_cityscapes.yaml
+++ b/configs/pix2pix_cityscapes.yaml
@@ -25,7 +25,7 @@ dataset:
   train:
     name: PairedDataset
     dataroot: data/cityscapes
-    num_workers: 0
+    num_workers: 4
     phase: train
     max_dataset_size: inf
     direction: BtoA
@@ -57,11 +57,12 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0002
-    start_epoch: 100
-    decay_epochs: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0002
+  start_epoch: 100
+  decay_epochs: 100
 
 log_config:
   interval: 100
@@ -69,4 +70,3 @@ log_config:
 
 snapshot_config:
   interval: 5
-
diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml
index 387f16bd2de43cd150a852b238057850607911be..a64b57a8c5e7bd80c71109bef58b8d8bf17fffff 100644
--- a/configs/pix2pix_cityscapes_2gpus.yaml
+++ b/configs/pix2pix_cityscapes_2gpus.yaml
@@ -56,11 +56,12 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0004
-    start_epoch: 100
-    decay_epochs: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0004
+  start_epoch: 100
+  decay_epochs: 100
 
 log_config:
   interval: 100
@@ -68,4 +69,3 @@ log_config:
 
 snapshot_config:
   interval: 5
-
diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml
index 06a8403f97723664e5fe330e7bafdd0d3171aa89..ede78386fdd09d6f67d797d658c494892c316fd9 100644
--- a/configs/pix2pix_facades.yaml
+++ b/configs/pix2pix_facades.yaml
@@ -56,11 +56,12 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0002
-    start_epoch: 100
-    decay_epochs: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0002
+  start_epoch: 100
+  decay_epochs: 100
 
 log_config:
   interval: 100
diff --git a/ppgan/datasets/base_dataset.py b/ppgan/datasets/base_dataset.py
index b382651d9c9cfc742bc3ddb192c364509abfb630..87e996925477c5fa096df48779327397ce22873b 100644
--- a/ppgan/datasets/base_dataset.py
+++ b/ppgan/datasets/base_dataset.py
@@ -6,7 +6,7 @@ from paddle.io import Dataset
 from PIL import Image
 import cv2
 
-import paddle.incubate.hapi.vision.transforms as transforms
+import paddle.vision.transforms as transforms
 from .transforms import transforms as T
 
 from abc import ABC, abstractmethod
@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
 class BaseDataset(Dataset, ABC):
     """This class is an abstract base class (ABC) for datasets.
     """
-
     def __init__(self, cfg):
         """Initialize the class; save the options in the class
 
@@ -60,8 +59,11 @@ def get_params(cfg, size):
 
     return {'crop_pos': (x, y), 'flip': flip}
 
-
-def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, convert=True):
+def get_transform(cfg,
+                  params=None,
+                  grayscale=False,
+                  method=cv2.INTER_CUBIC,
+                  convert=True):
     transform_list = []
     if grayscale:
         print('grayscale not support for now!!!')
@@ -89,8 +91,10 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con
             transform_list.append(transforms.RandomHorizontalFlip())
         elif params['flip']:
             transform_list.append(transforms.RandomHorizontalFlip(1.0))
-
+
     if convert:
         transform_list += [transforms.Permute(to_rgb=True)]
-        transform_list += [transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))]
+        transform_list += [
+            transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
+        ]
     return transforms.Compose(transform_list)
diff --git a/ppgan/datasets/builder.py b/ppgan/datasets/builder.py
index c26137a755e11d65709547617f7683f236883397..62b5346795c1383d683926e46064b97ea8a14aee 100644
--- a/ppgan/datasets/builder.py
+++ b/ppgan/datasets/builder.py
@@ -3,12 +3,11 @@ import paddle
 import numbers
 import numpy as np
 from multiprocessing import Manager
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv
 
-from paddle.incubate.hapi.distributed import DistributedBatchSampler
+from paddle.io import DistributedBatchSampler
 
 from ..utils.registry import Registry
-
 
 DATASETS = Registry("DATASETS")
@@ -21,7 +20,7 @@ class DictDataset(paddle.io.Dataset):
         single_item = dataset[0]
 
         self.keys = single_item.keys()
-
+
         for k, v in single_item.items():
             if not isinstance(v, (numbers.Number, np.ndarray)):
                 setattr(self, k, Manager().dict())
@@ -32,9 +31,9 @@ class DictDataset(paddle.io.Dataset):
     def __getitem__(self, index):
         ori_map = self.dataset[index]
-
+
         tmp_list = []
-
+
         for k, v in ori_map.items():
             if isinstance(v, (numbers.Number, np.ndarray)):
                 tmp_list.append(v)
@@ -60,17 +59,15 @@ class DictDataLoader():
 
         place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
                     if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)
 
-        sampler = DistributedBatchSampler(
-            self.dataset,
-            batch_size=batch_size,
-            shuffle=True if is_train else False,
-            drop_last=True if is_train else False)
+        sampler = DistributedBatchSampler(self.dataset,
+                                          batch_size=batch_size,
+                                          shuffle=True if is_train else False,
+                                          drop_last=True if is_train else False)
 
-        self.dataloader = paddle.io.DataLoader(
-            self.dataset,
-            batch_sampler=sampler,
-            places=place,
-            num_workers=num_workers)
+        self.dataloader = paddle.io.DataLoader(self.dataset,
+                                               batch_sampler=sampler,
+                                               places=place,
+                                               num_workers=num_workers)
 
         self.batch_size = batch_size
@@ -83,7 +80,9 @@ class DictDataLoader():
         j = 0
         for k in self.dataset.keys:
             if k in self.dataset.tensor_keys_set:
-                return_dict[k] = data[j] if isinstance(data, (list, tuple)) else data
+                return_dict[k] = data[j] if isinstance(data,
+                                                       (list,
+                                                        tuple)) else data
                 j += 1
             else:
                 return_dict[k] = self.get_items_by_indexs(k, data[-1])
@@ -104,13 +103,12 @@ class DictDataLoader():
 
         return current_items
 
-
 def build_dataloader(cfg, is_train=True):
     dataset = DATASETS.get(cfg.name)(cfg)
-
+
     batch_size = cfg.get('batch_size', 1)
     num_workers = cfg.get('num_workers', 0)
 
     dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)
 
-    return dataloader
\ No newline at end of file
+    return dataloader
diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py
index 82e82f6d7efdf2f51fb709d960d8fc912cb61e9b..f7f456962e61c4358853bcf2365006fd340e68f6 100644
--- a/ppgan/engine/trainer.py
+++ b/ppgan/engine/trainer.py
@@ -4,7 +4,7 @@ import time
 import logging
 
 import paddle
-from paddle import ParallelEnv, DataParallel
+from paddle.distributed import ParallelEnv
 
 from ..datasets.builder import build_dataloader
 from ..models.builder import build_model
@@ -17,10 +17,11 @@ class Trainer:
 
         # build train dataloader
         self.train_dataloader = build_dataloader(cfg.dataset.train)
-
+
         if 'lr_scheduler' in cfg.optimizer:
-            cfg.optimizer.lr_scheduler.step_per_epoch = len(self.train_dataloader)
-
+            cfg.optimizer.lr_scheduler.step_per_epoch = len(
+                self.train_dataloader)
+
         # build model
         self.model = build_model(cfg)
         # multiple gpus prepare
@@ -44,16 +45,17 @@ class Trainer:
 
         # time count
         self.time_count = {}
-
+
     def distributed_data_parallel(self):
         strategy = paddle.prepare_context()
         for name in self.model.model_names:
             if isinstance(name, str):
                 net = getattr(self.model, 'net' + name)
-                setattr(self.model, 'net' + name, DataParallel(net, strategy))
+                setattr(self.model, 'net' + name,
+                        paddle.DataParallel(net, strategy))
 
     def train(self):
-
+
         for epoch in range(self.start_epoch, self.epochs):
             self.current_epoch = epoch
             start_time = step_start_time = time.time()
@@ -64,24 +66,27 @@ class Trainer:
                 # data input should be dict
                 self.model.set_input(data)
                 self.model.optimize_parameters()
-
+
                 self.data_time = data_time - step_start_time
                 self.step_time = time.time() - step_start_time
                 if i % self.log_interval == 0:
                     self.print_log()
-
+
                 if i % self.visual_interval == 0:
                     self.visual('visual_train')
 
                 step_start_time = time.time()
 
-            self.logger.info('train one epoch time: {}'.format(time.time() - start_time))
+            self.logger.info('train one epoch time: {}'.format(time.time() -
+                                                               start_time))
+            self.model.lr_scheduler.step()
             if epoch % self.weight_interval == 0:
                 self.save(epoch, 'weight', keep=-1)
             self.save(epoch)
 
     def test(self):
         if not hasattr(self, 'test_dataloader'):
-            self.test_dataloader = build_dataloader(self.cfg.dataset.test, is_train=False)
+            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
+                                                    is_train=False)
 
         # data[0]: img, data[1]: img path index
         # test batch size must be 1
@@ -103,14 +108,15 @@ class Trainer:
                     visual_results.update({name: img_tensor[j]})
 
             self.visual('visual_test', visual_results=visual_results)
-
+
             if i % self.log_interval == 0:
-                self.logger.info('Test iter: [%d/%d]' % (i, len(self.test_dataloader)))
+                self.logger.info('Test iter: [%d/%d]' %
+                                 (i, len(self.test_dataloader)))
 
     def print_log(self):
         losses = self.model.get_current_losses()
 
         message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id)
-
+
         message += '%s: %.6f ' % ('lr', self.current_learning_rate)
 
         for k, v in losses.items():
@@ -143,13 +149,14 @@ class Trainer:
             makedirs(os.path.join(self.output_dir, results_dir))
             for label, image in visual_results.items():
                 image_numpy = tensor2img(image)
-                img_path = os.path.join(self.output_dir, results_dir, msg + '%s.png' % (label))
+                img_path = os.path.join(self.output_dir, results_dir,
+                                        msg + '%s.png' % (label))
                 save_image(image_numpy, img_path)
 
     def save(self, epoch, name='checkpoint', keep=1):
         if self.local_rank != 0:
             return
-
+
         assert name in ['checkpoint', 'weight']
 
         state_dicts = {}
@@ -175,8 +182,8 @@ class Trainer:
 
         if keep > 0:
             try:
-                checkpoint_name_to_be_removed = os.path.join(self.output_dir,
-                                                'epoch_%s_%s.pkl' % (epoch - keep, name))
+                checkpoint_name_to_be_removed = os.path.join(
+                    self.output_dir, 'epoch_%s_%s.pkl' % (epoch - keep, name))
 
                 if os.path.exists(checkpoint_name_to_be_removed):
                     os.remove(checkpoint_name_to_be_removed)
@@ -187,7 +194,7 @@ class Trainer:
         state_dicts = load(checkpoint_path)
         if state_dicts.get('epoch', None) is not None:
             self.start_epoch = state_dicts['epoch'] + 1
-
+
         for name in self.model.model_names:
             if isinstance(name, str):
                 net = getattr(self.model, 'net' + name)
@@ -200,9 +207,8 @@ class Trainer:
 
     def load(self, weight_path):
         state_dicts = load(weight_path)
-
+
         for name in self.model.model_names:
             if isinstance(name, str):
                 net = getattr(self.model, 'net' + name)
                 net.set_dict(state_dicts['net' + name])
-
\ No newline at end of file
diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py
index 6d038c77008e32002eff09c46f06f474567c93a9..93720c2c025fc17b2f35294293c825f2f1d0519c 100644
--- a/ppgan/models/base_model.py
+++ b/ppgan/models/base_model.py
@@ -5,6 +5,7 @@ import numpy as np
 from collections import OrderedDict
 from abc import ABC, abstractmethod
 
+from ..solver.lr_scheduler import build_lr_scheduler
 
 
 class BaseModel(ABC):
@@ -16,7 +17,6 @@ class BaseModel(ABC):
         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
     """
-
    def __init__(self, opt):
        """Initialize the BaseModel class.
@@ -33,8 +33,10 @@ class BaseModel(ABC):
        """
        self.opt = opt
        self.isTrain = opt.isTrain
-        self.save_dir = os.path.join(opt.output_dir, opt.model.name)  # save all the checkpoints to save_dir
-
+        self.save_dir = os.path.join(
+            opt.output_dir,
+            opt.model.name)  # save all the checkpoints to save_dir
+
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
@@ -75,6 +77,8 @@ class BaseModel(ABC):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass
 
+    def build_lr_scheduler(self):
+        self.lr_scheduler = build_lr_scheduler(self.opt.lr_scheduler)
 
    def eval(self):
        """Make models eval mode during test time"""
@@ -114,10 +118,11 @@ class BaseModel(ABC):
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
-                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
+                errors_ret[name] = float(
+                    getattr(self, 'loss_' + name)
+                )  # float(...) works for both scalar tensor and float number
         return errors_ret
 
-
     def set_requires_grad(self, nets, requires_grad=False):
         """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
         Parameters:
diff --git a/ppgan/models/cycle_gan_model.py b/ppgan/models/cycle_gan_model.py
index 25f2547a9c3a7ef7aa231c4fae2608496302b0ee..c6afecdea1a921ec7edb434535b013d55341a7f1 100644
--- a/ppgan/models/cycle_gan_model.py
+++ b/ppgan/models/cycle_gan_model.py
@@ -1,5 +1,5 @@
 import paddle
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv
 
 from .base_model import BaseModel
 from .builder import MODELS
@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel):
 
     CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
     """
-
     def __init__(self, opt):
         """Initialize the CycleGAN class.
 
@@ -32,12 +31,14 @@ class CycleGANModel(BaseModel):
         """
         BaseModel.__init__(self, opt)
         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
-        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
+        self.loss_names = [
+            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
+        ]
         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
         visual_names_A = ['real_A', 'fake_B', 'rec_A']
         visual_names_B = ['real_B', 'fake_A', 'rec_B']
-        # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) 
+        # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
         if self.isTrain and self.opt.lambda_identity > 0.0:
             visual_names_A.append('idt_B')
             visual_names_B.append('idt_A')
@@ -62,18 +63,28 @@ class CycleGANModel(BaseModel):
 
         if self.isTrain:
             if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
-                assert(opt.dataset.train.input_nc == opt.dataset.train.output_nc)
+                assert (
+                    opt.dataset.train.input_nc == opt.dataset.train.output_nc)
             # create image buffer to store previously generated images
             self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
             # create image buffer to store previously generated images
             self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)
             # define loss functions
             self.criterionGAN = GANLoss(opt.model.gan_mode)
-            self.criterionCycle = paddle.nn.L1Loss() 
+            self.criterionCycle = paddle.nn.L1Loss()
             self.criterionIdt = paddle.nn.L1Loss()
-
-            self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG_A.parameters() + self.netG_B.parameters())
-            self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters() + self.netD_B.parameters())
+
+            self.build_lr_scheduler()
+            self.optimizer_G = build_optimizer(
+                opt.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netG_A.parameters() +
+                self.netG_B.parameters())
+            self.optimizer_D = build_optimizer(
+                opt.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netD_A.parameters() +
+                self.netD_B.parameters())
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
@@ -90,7 +101,7 @@ class CycleGANModel(BaseModel):
         """
         mode = 'train' if self.isTrain else 'test'
         AtoB = self.opt.dataset[mode].direction == 'AtoB'
-
+
         if AtoB:
             if 'A' in input:
                 self.real_A = paddle.to_tensor(input['A'])
@@ -107,17 +118,15 @@ class CycleGANModel(BaseModel):
             elif 'B_paths' in input:
                 self.image_paths = input['B_paths']
 
-
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         if hasattr(self, 'real_A'):
             self.fake_B = self.netG_A(self.real_A)  # G_A(A)
-            self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
+            self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
 
         if hasattr(self, 'real_B'):
             self.fake_A = self.netG_B(self.real_B)  # G_B(B)
-            self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
-
+            self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))
 
     def backward_D_basic(self, netD, real, fake):
         """Calculate GAN loss for the discriminator
@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel):
         if lambda_idt > 0:
             # G_A should be identity if real_B is fed: ||G_A(B) - B||
             self.idt_A = self.netG_A(self.real_B)
-            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+            self.loss_idt_A = self.criterionIdt(
+                self.idt_A, self.real_B) * lambda_B * lambda_idt
             # G_B should be identity if real_A is fed: ||G_B(A) - A||
             self.idt_B = self.netG_B(self.real_A)
-            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+            self.loss_idt_B = self.criterionIdt(
+                self.idt_B, self.real_A) * lambda_A * lambda_idt
         else:
             self.loss_idt_A = 0
             self.loss_idt_B = 0
@@ -179,12 +190,14 @@ class CycleGANModel(BaseModel):
         # GAN loss D_B(G_B(B))
         self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
         # Forward cycle loss || G_B(G_A(A)) - A||
-        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+        self.loss_cycle_A = self.criterionCycle(self.rec_A,
+                                                self.real_A) * lambda_A
         # Backward cycle loss || G_A(G_B(B)) - B||
-        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+        self.loss_cycle_B = self.criterionCycle(self.rec_B,
+                                                self.real_B) * lambda_B
         # combined loss and calculate gradients
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
-
+
         if ParallelEnv().nranks > 1:
             self.loss_G = self.netG_A.scale_loss(self.loss_G)
             self.loss_G.backward()
@@ -216,6 +229,5 @@ class CycleGANModel(BaseModel):
         self.backward_D_A()
         # calculate graidents for D_B
         self.backward_D_B()
-        # update D_A and D_B's weights 
+        # update D_A and D_B's weights
         self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)
-
diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py
index 50009175843722f9ef7e200dd6c412cbfebb803b..7acecc7d42d0d2ade737092af50ddc3fbfacabb6 100644
--- a/ppgan/models/pix2pix_model.py
+++ b/ppgan/models/pix2pix_model.py
@@ -1,5 +1,5 @@
 import paddle
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv
 
 from .base_model import BaseModel
 from .builder import MODELS
@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel):
 
     pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
     """
-
     def __init__(self, opt):
         """Initialize the pix2pix class.
 
@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel):
 
         if self.isTrain:
             self.netD = build_discriminator(opt.model.discriminator)
-
         if self.isTrain:
             # define loss functions
             self.criterionGAN = GANLoss(opt.model.gan_mode)
             self.criterionL1 = paddle.nn.L1Loss()
 
             # build optimizers
-            self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters())
-            self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters())
+            self.build_lr_scheduler()
+            self.optimizer_G = build_optimizer(
+                opt.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netG.parameters())
+            self.optimizer_D = build_optimizer(
+                opt.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netD.parameters())
 
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
@@ -75,7 +80,6 @@ class Pix2PixModel(BaseModel):
         self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
         self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
         self.image_paths = input['A_paths' if AtoB else 'B_paths']
-
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
@@ -84,7 +88,7 @@ class Pix2PixModel(BaseModel):
 
     def forward_test(self, input):
         input = paddle.imperative.to_variable(input)
         return self.netG(input)
-
+
     def backward_D(self):
         """Calculate GAN loss for the discriminator"""
         # Fake; stop backprop to the generator by detaching fake_B
@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel):
         pred_fake = self.netD(fake_AB)
         self.loss_G_GAN = self.criterionGAN(pred_fake, True)
         # Second, G(A) = B
-        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
+        self.loss_G_L1 = self.criterionL1(self.fake_B,
+                                          self.real_B) * self.opt.lambda_L1
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
@@ -129,12 +134,12 @@ class Pix2PixModel(BaseModel):
 
         # update D
         self.set_requires_grad(self.netD, True)
-        self.optimizer_D.clear_gradients() 
+        self.optimizer_D.clear_gradients()
         self.backward_D()
-        self.optimizer_D.minimize(self.loss_D) 
-
+        self.optimizer_D.minimize(self.loss_D)
+
         # update G
-        self.set_requires_grad(self.netD, False) 
+        self.set_requires_grad(self.netD, False)
         self.optimizer_G.clear_gradients()
         self.backward_G()
         self.optimizer_G.minimize(self.loss_G)
diff --git a/ppgan/solver/lr_scheduler.py b/ppgan/solver/lr_scheduler.py
index 9bcbd7af78c368c8b1e7c59805890c395c8e2e49..3c17e3da0f06848ec446a5fbf141762eeae0b918 100644
--- a/ppgan/solver/lr_scheduler.py
+++ b/ppgan/solver/lr_scheduler.py
@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg):
     # TODO: add more learning rate scheduler
     if name == 'linear':
-        return LinearDecay(**cfg)
+
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(
+                0, epoch + 1 - cfg.start_epoch) / float(cfg.decay_epochs + 1)
+            return lr_l
+
+        scheduler = paddle.optimizer.lr_scheduler.LambdaLR(
+            cfg.learning_rate, lr_lambda=lambda_rule)
+        return scheduler
     else:
         raise NotImplementedError
 
 
-class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay):
-    def __init__(self, learning_rate, step_per_epoch, start_epoch, decay_epochs):
+# paddle.optimizer.lr_scheduler
+class LinearDecay(paddle.optimizer.lr_scheduler._LRScheduler):
+    def __init__(self, learning_rate, step_per_epoch, start_epoch,
+                 decay_epochs):
         super(LinearDecay, self).__init__()
         self.learning_rate = learning_rate
         self.start_epoch = start_epoch
@@ -21,5 +31,6 @@ class LinearDecay
 
     def step(self):
         cur_epoch = int(self.step_num // self.step_per_epoch)
-        decay_rate = 1.0 - max(0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
-        return self.create_lr_var(decay_rate * self.learning_rate)
\ No newline at end of file
+        decay_rate = 1.0 - max(
+            0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
+        return self.create_lr_var(decay_rate * self.learning_rate)
diff --git a/ppgan/solver/optimizer.py b/ppgan/solver/optimizer.py
index 3f96c08efc75828db96e97aa5d1374b32b5264de..810cada4020f299f8611d2ddc7eb1d38c4bf7b50 100644
--- a/ppgan/solver/optimizer.py
+++ b/ppgan/solver/optimizer.py
@@ -4,13 +4,11 @@ import paddle
 
 from .lr_scheduler import build_lr_scheduler
 
 
-def build_optimizer(cfg, parameter_list=None):
+def build_optimizer(cfg, lr_scheduler, parameter_list=None):
     cfg_copy = copy.deepcopy(cfg)
-
-    lr_scheduler_cfg = cfg_copy.pop('lr_scheduler', None)
-
-    lr_scheduler = build_lr_scheduler(lr_scheduler_cfg)
 
     opt_name = cfg_copy.pop('name')
 
-    return getattr(paddle.optimizer, opt_name)(lr_scheduler, parameters=parameter_list, **cfg_copy)
+    return getattr(paddle.optimizer, opt_name)(lr_scheduler,
+                                               parameters=parameter_list,
+                                               **cfg_copy)
diff --git a/ppgan/utils/logger.py b/ppgan/utils/logger.py
index f65b3350df63e9a18038de55f8c952055c70d7fe..58d94f69e6a11be9dee35fbb16afbc298ae7e359 100644
--- a/ppgan/utils/logger.py
+++ b/ppgan/utils/logger.py
@@ -2,7 +2,7 @@ import logging
 import os
 import sys
 
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv
 
 
 def setup_logger(output=None, name="ppgan"):
@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"):
     logger.propagate = False
 
     plain_formatter = logging.Formatter(
-        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
-    )
+        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
+        datefmt="%m/%d %H:%M:%S")
     # stdout logging: master only
     local_rank = ParallelEnv().local_rank
     if local_rank == 0:
@@ -52,4 +52,4 @@ def setup_logger(output=None, name="ppgan"):
         fh.setFormatter(plain_formatter)
         logger.addHandler(fh)
 
-    return logger
\ No newline at end of file
+    return logger
diff --git a/ppgan/utils/setup.py b/ppgan/utils/setup.py
index d56bf364788a0e8a756abd52a62378e68c7c1a11..f663ba960e21c1c1647d7acfdcad7af43d638367 100644
--- a/ppgan/utils/setup.py
+++ b/ppgan/utils/setup.py
@@ -2,7 +2,7 @@ import os
 import time
 
 import paddle
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv
 
 from .logger import setup_logger
 
@@ -12,7 +12,8 @@ def setup(args, cfg):
         cfg.isTrain = False
 
     cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
-    cfg.output_dir = os.path.join(cfg.output_dir, str(cfg.model.name) + cfg.timestamp)
+    cfg.output_dir = os.path.join(cfg.output_dir,
+                                  str(cfg.model.name) + cfg.timestamp)
 
     logger = setup_logger(cfg.output_dir)