提交 abd3250d 编写于 作者: L LielinJiang

adapt to api 2.0 again

上级 4a3ba224
......@@ -11,7 +11,7 @@ from imageio import imread, imsave
import cv2
import paddle.fluid as fluid
from paddle.incubate.hapi.download import get_path_from_url
from paddle.utils.download import get_path_from_url
import networks
from util import *
......@@ -19,6 +19,7 @@ from my_args import parser
DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'
def infer_engine(model_dir,
run_mode='fluid',
batch_size=1,
......@@ -91,7 +92,6 @@ class VideoFrameInterp(object):
self.exe, self.program, self.fetch_targets = executor(model_path,
use_gpu=use_gpu)
def run(self):
frame_path_input = os.path.join(self.output_path, 'frames-input')
frame_path_interpolated = os.path.join(self.output_path,
......@@ -272,7 +272,7 @@ class VideoFrameInterp(object):
os.remove(video_pattern_output)
frames_to_video_ffmpeg(frame_pattern_combined, video_pattern_output,
r2)
return frame_pattern_combined, video_pattern_output
......
......@@ -15,15 +15,19 @@ from PIL import Image
from tqdm import tqdm
from paddle import fluid
from model import build_model
from paddle.incubate.hapi.download import get_path_from_url
from paddle.utils.download import get_path_from_url
parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--weight_path', type=str, default='none', help='Path to the reference image directory')
parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--weight_path',
type=str,
default='none',
help='Path to the reference image directory')
DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
cmd = ffmpeg + [
......@@ -56,9 +60,9 @@ class DeOldifyPredictor():
def norm(self, img, render_factor=32, render_base=16):
target_size = render_factor * render_base
img = img.resize((target_size, target_size), resample=Image.BILINEAR)
img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
......@@ -69,13 +73,13 @@ class DeOldifyPredictor():
def denorm(self, img):
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
img *= img_std
img += img_mean
img = img.transpose((1, 2, 0))
return (img * 255).clip(0, 255).astype('uint8')
def post_process(self, raw_color, orig):
color_np = np.asarray(raw_color)
orig_np = np.asarray(orig)
......@@ -86,11 +90,11 @@ class DeOldifyPredictor():
final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
final = Image.fromarray(final)
return final
def run_single(self, img_path):
ori_img = Image.open(img_path).convert('LA').convert('RGB')
img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis,...])
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
pred_img = self.denorm(out.numpy()[0])
......@@ -118,20 +122,20 @@ class DeOldifyPredictor():
frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
for frame in tqdm(frames):
pred_img = self.run_single(frame)
frame_name = os.path.basename(frame)
pred_img.save(os.path.join(pred_frame_path, frame_name))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(output_path, '{}_deoldify_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps)))
return frame_pattern_combined, vid_out_path
vid_out_path = os.path.join(output_path,
'{}_deoldify_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
str(int(fps)))
return frame_pattern_combined, vid_out_path
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
......@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
if ss is not None and t is not None and r is not None:
cmd = ffmpeg + [
' -ss ',
ss,
' -t ',
t,
' -i ',
vid_path,
' -r ',
r,
' -qscale:v ',
' 0.1 ',
' -start_number ',
' 0 ',
outformat
' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
' 0.1 ', ' -start_number ', ' 0 ', outformat
]
else:
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
......@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
return out_full_path
if __name__=='__main__':
if __name__ == '__main__':
paddle.enable_imperative()
args = parser.parse_args()
predictor = DeOldifyPredictor(args.input, args.output, weight_path=args.weight_path)
predictor = DeOldifyPredictor(args.input,
args.output,
weight_path=args.weight_path)
frames_path, temp_video_path = predictor.run()
print('output video path:', temp_video_path)
\ No newline at end of file
print('output video path:', temp_video_path)
......@@ -15,195 +15,235 @@ import argparse
import subprocess
import utils
from remasternet import NetworkR, NetworkC
from paddle.incubate.hapi.download import get_path_from_url
from paddle.utils.download import get_path_from_url
DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
parser = argparse.ArgumentParser(description='Remastering')
parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory')
parser.add_argument('--colorization', action='store_true', default=False, help='Remaster without colorization')
parser.add_argument('--mindim', type=int, default='360', help='Length of minimum image edges')
parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--reference_dir',
type=str,
default=None,
help='Path to the reference image directory')
parser.add_argument('--colorization',
action='store_true',
default=False,
help='Remaster without colorization')
parser.add_argument('--mindim',
type=int,
default='360',
help='Length of minimum image edges')
class DeepReasterPredictor:
def __init__(self, input, output, weight_path=None, colorization=False, reference_dir=None, mindim=360):
self.input = input
self.output = os.path.join(output, 'DeepRemaster')
self.colorization = colorization
self.reference_dir = reference_dir
self.mindim = mindim
if weight_path is None:
weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path)
state_dict, _ = paddle.load(weight_path)
self.modelR = NetworkR()
self.modelR.load_dict(state_dict['modelR'])
self.modelR.eval()
if colorization:
self.modelC = NetworkC()
self.modelC.load_dict(state_dict['modelC'])
self.modelC.eval()
def run(self):
outputdir = self.output
outputdir_in = os.path.join(outputdir, 'input/')
os.makedirs( outputdir_in, exist_ok=True )
outputdir_out = os.path.join(outputdir, 'output/')
os.makedirs( outputdir_out, exist_ok=True )
# Prepare reference images
if self.colorization:
if self.reference_dir is not None:
import glob
ext_list = ['png','jpg','bmp']
reference_files = []
for ext in ext_list:
reference_files += glob.glob( self.reference_dir+'/*.'+ext, recursive=True )
aspect_mean = 0
minedge_dim = 256
refs = []
for v in reference_files:
refimg = Image.open( v ).convert('RGB')
w, h = refimg.size
aspect_mean += w/h
refs.append( refimg )
aspect_mean /= len(reference_files)
target_w = int(256*aspect_mean) if aspect_mean>1 else 256
target_h = 256 if aspect_mean>=1 else int(256/aspect_mean)
refimgs = []
for i, v in enumerate(refs):
refimg = utils.addMergin( v, target_w=target_w, target_h=target_h )
refimg = np.array(refimg).astype('float32').transpose(2, 0, 1) / 255.0
refimgs.append(refimg)
refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
refimgs = paddle.unsqueeze(refimgs, 0)
# Load video
cap = cv2.VideoCapture( self.input )
nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
minwh = min(v_w,v_h)
scale = 1
if minwh != self.mindim:
scale = self.mindim / minwh
t_w = round(v_w*scale/16.)*16
t_h = round(v_h*scale/16.)*16
fps = cap.get(cv2.CAP_PROP_FPS)
pbar = tqdm(total=nframes)
block = 5
# Process
with paddle.no_grad():
it = 0
while True:
frame_pos = it*block
if frame_pos >= nframes:
break
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
if block >= nframes-frame_pos:
proc_g = nframes-frame_pos
else:
proc_g = block
input = None
gtC = None
for i in range(proc_g):
index = frame_pos + i
_, frame = cap.read()
frame = cv2.resize(frame, (t_w, t_h))
nchannels = frame.shape[2]
if nchannels == 1 or self.colorization:
frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
cv2.imwrite(outputdir_in+'%07d.png'%index, frame_l)
frame_l = paddle.to_tensor(frame_l.astype('float32'))
frame_l = paddle.reshape(frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
frame_l = paddle.transpose(frame_l, [2, 0, 1])
frame_l /= 255.
frame_l = paddle.reshape(frame_l, [1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
elif nchannels == 3:
cv2.imwrite(outputdir_in+'%07d.png'%index, frame)
frame = frame[:,:,::-1] ## BGR -> RGB
frame_l, frame_ab = utils.convertRGB2LABTensor( frame )
frame_l = frame_l.transpose([2, 0, 1])
frame_ab = frame_ab.transpose([2, 0, 1])
frame_l = frame_l.reshape([1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
frame_ab = frame_ab.reshape([1, frame_ab.shape[0], 1, frame_ab.shape[1], frame_ab.shape[2]])
if input is not None:
paddle.concat( (input, frame_l), 2 )
input = frame_l if i==0 else paddle.concat( (input, frame_l), 2 )
if nchannels==3 and not self.colorization:
gtC = frame_ab if i==0 else paddle.concat( (gtC, frame_ab), 2 )
input = paddle.to_tensor(input)
output_l = self.modelR( input ) # [B, C, T, H, W]
# Save restoration output without colorization when using the option [--disable_colorization]
if not self.colorization:
for i in range( proc_g ):
index = frame_pos + i
if nchannels==3:
out_l = output_l.detach()[0,:,i]
out_ab = gtC[0,:,i]
out = paddle.concat((out_l, out_ab),axis=0).detach().numpy().transpose((1, 2, 0))
out = Image.fromarray( np.uint8( utils.convertLAB2RGB( out )*255 ) )
out.save( outputdir_out+'%07d.png'%(index) )
else:
raise ValueError('channels of imag3 must be 3!')
# Perform colorization
else:
if self.reference_dir is None:
output_ab = self.modelC( output_l )
else:
output_ab = self.modelC( output_l, refimgs )
output_l = output_l.detach()
output_ab = output_ab.detach()
for i in range( proc_g ):
index = frame_pos + i
out_l = output_l[0,:,i,:,:]
out_c = output_ab[0,:,i,:,:]
output = paddle.concat((out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
output = Image.fromarray( np.uint8( utils.convertLAB2RGB( output )*255 ) )
output.save( outputdir_out+'%07d.png'%index )
it = it + 1
pbar.update(proc_g)
# Save result videos
outfile = os.path.join(outputdir, self.input.split('/')[-1].split('.')[0])
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (fps, outputdir_in, fps, outfile )
subprocess.call( cmd, shell=True )
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (fps, outputdir_out, fps, outfile )
subprocess.call( cmd, shell=True )
cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % ( outfile, outfile, outfile )
subprocess.call( cmd, shell=True )
cap.release()
pbar.close()
return outputdir_out, '%s_out.mp4' % outfile
def __init__(self,
input,
output,
weight_path=None,
colorization=False,
reference_dir=None,
mindim=360):
self.input = input
self.output = os.path.join(output, 'DeepRemaster')
self.colorization = colorization
self.reference_dir = reference_dir
self.mindim = mindim
if weight_path is None:
weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path)
state_dict, _ = paddle.load(weight_path)
self.modelR = NetworkR()
self.modelR.load_dict(state_dict['modelR'])
self.modelR.eval()
if colorization:
self.modelC = NetworkC()
self.modelC.load_dict(state_dict['modelC'])
self.modelC.eval()
def run(self):
outputdir = self.output
outputdir_in = os.path.join(outputdir, 'input/')
os.makedirs(outputdir_in, exist_ok=True)
outputdir_out = os.path.join(outputdir, 'output/')
os.makedirs(outputdir_out, exist_ok=True)
# Prepare reference images
if self.colorization:
if self.reference_dir is not None:
import glob
ext_list = ['png', 'jpg', 'bmp']
reference_files = []
for ext in ext_list:
reference_files += glob.glob(self.reference_dir + '/*.' +
ext,
recursive=True)
aspect_mean = 0
minedge_dim = 256
refs = []
for v in reference_files:
refimg = Image.open(v).convert('RGB')
w, h = refimg.size
aspect_mean += w / h
refs.append(refimg)
aspect_mean /= len(reference_files)
target_w = int(256 * aspect_mean) if aspect_mean > 1 else 256
target_h = 256 if aspect_mean >= 1 else int(256 / aspect_mean)
refimgs = []
for i, v in enumerate(refs):
refimg = utils.addMergin(v,
target_w=target_w,
target_h=target_h)
refimg = np.array(refimg).astype('float32').transpose(
2, 0, 1) / 255.0
refimgs.append(refimg)
refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
refimgs = paddle.unsqueeze(refimgs, 0)
# Load video
cap = cv2.VideoCapture(self.input)
nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
minwh = min(v_w, v_h)
scale = 1
if minwh != self.mindim:
scale = self.mindim / minwh
t_w = round(v_w * scale / 16.) * 16
t_h = round(v_h * scale / 16.) * 16
fps = cap.get(cv2.CAP_PROP_FPS)
pbar = tqdm(total=nframes)
block = 5
# Process
with paddle.no_grad():
it = 0
while True:
frame_pos = it * block
if frame_pos >= nframes:
break
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
if block >= nframes - frame_pos:
proc_g = nframes - frame_pos
else:
proc_g = block
input = None
gtC = None
for i in range(proc_g):
index = frame_pos + i
_, frame = cap.read()
frame = cv2.resize(frame, (t_w, t_h))
nchannels = frame.shape[2]
if nchannels == 1 or self.colorization:
frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
cv2.imwrite(outputdir_in + '%07d.png' % index, frame_l)
frame_l = paddle.to_tensor(frame_l.astype('float32'))
frame_l = paddle.reshape(
frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
frame_l = paddle.transpose(frame_l, [2, 0, 1])
frame_l /= 255.
frame_l = paddle.reshape(frame_l, [
1, frame_l.shape[0], 1, frame_l.shape[1],
frame_l.shape[2]
])
elif nchannels == 3:
cv2.imwrite(outputdir_in + '%07d.png' % index, frame)
frame = frame[:, :, ::-1] ## BGR -> RGB
frame_l, frame_ab = utils.convertRGB2LABTensor(frame)
frame_l = frame_l.transpose([2, 0, 1])
frame_ab = frame_ab.transpose([2, 0, 1])
frame_l = frame_l.reshape([
1, frame_l.shape[0], 1, frame_l.shape[1],
frame_l.shape[2]
])
frame_ab = frame_ab.reshape([
1, frame_ab.shape[0], 1, frame_ab.shape[1],
frame_ab.shape[2]
])
if input is not None:
paddle.concat((input, frame_l), 2)
input = frame_l if i == 0 else paddle.concat(
(input, frame_l), 2)
if nchannels == 3 and not self.colorization:
gtC = frame_ab if i == 0 else paddle.concat(
(gtC, frame_ab), 2)
input = paddle.to_tensor(input)
output_l = self.modelR(input) # [B, C, T, H, W]
# Save restoration output without colorization when using the option [--disable_colorization]
if not self.colorization:
for i in range(proc_g):
index = frame_pos + i
if nchannels == 3:
out_l = output_l.detach()[0, :, i]
out_ab = gtC[0, :, i]
out = paddle.concat(
(out_l, out_ab),
axis=0).detach().numpy().transpose((1, 2, 0))
out = Image.fromarray(
np.uint8(utils.convertLAB2RGB(out) * 255))
out.save(outputdir_out + '%07d.png' % (index))
else:
raise ValueError('channels of imag3 must be 3!')
# Perform colorization
else:
if self.reference_dir is None:
output_ab = self.modelC(output_l)
else:
output_ab = self.modelC(output_l, refimgs)
output_l = output_l.detach()
output_ab = output_ab.detach()
for i in range(proc_g):
index = frame_pos + i
out_l = output_l[0, :, i, :, :]
out_c = output_ab[0, :, i, :, :]
output = paddle.concat(
(out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
output = Image.fromarray(
np.uint8(utils.convertLAB2RGB(output) * 255))
output.save(outputdir_out + '%07d.png' % index)
it = it + 1
pbar.update(proc_g)
# Save result videos
outfile = os.path.join(outputdir,
self.input.split('/')[-1].split('.')[0])
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (
fps, outputdir_in, fps, outfile)
subprocess.call(cmd, shell=True)
cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (
fps, outputdir_out, fps, outfile)
subprocess.call(cmd, shell=True)
cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % (
outfile, outfile, outfile)
subprocess.call(cmd, shell=True)
cap.release()
pbar.close()
return outputdir_out, '%s_out.mp4' % outfile
if __name__ == "__main__":
args = parser.parse_args()
paddle.disable_static()
predictor = DeepReasterPredictor(args.input, args.output, colorization=args.colorization,
reference_dir=args.reference_dir, mindim=args.mindim)
predictor.run()
\ No newline at end of file
args = parser.parse_args()
paddle.disable_static()
predictor = DeepReasterPredictor(args.input,
args.output,
colorization=args.colorization,
reference_dir=args.reference_dir,
mindim=args.mindim)
predictor.run()
......@@ -28,30 +28,29 @@ import paddle.fluid as fluid
import cv2
from data import EDVRDataset
from paddle.incubate.hapi.download import get_path_from_url
from paddle.utils.download import get_path_from_url
EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--input',
type=str,
default=None,
help='input video path')
parser.add_argument(
'--output',
type=str,
default='output',
help='output path')
parser.add_argument(
'--weight_path',
type=str,
default=None,
help='weight path')
parser.add_argument('--input',
type=str,
default=None,
help='input video path')
parser.add_argument('--output',
type=str,
default='output',
help='output path')
parser.add_argument('--weight_path',
type=str,
default=None,
help='weight path')
args = parser.parse_args()
return args
def get_img(pred):
print('pred shape', pred.shape)
pred = pred.squeeze()
......@@ -59,10 +58,11 @@ def get_img(pred):
pred = pred * 255
pred = pred.round()
pred = pred.astype('uint8')
pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc
pred = pred[:, :, ::-1] # rgb -> bgr
pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc
pred = pred[:, :, ::-1] # rgb -> bgr
return pred
def save_img(img, framename):
dirname = os.path.dirname(framename)
if not os.path.exists(dirname):
......@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
if ss is not None and t is not None and r is not None:
cmd = ffmpeg + [
' -ss ',
ss,
' -t ',
t,
' -i ',
vid_path,
' -r ',
r,
' -qscale:v ',
' 0.1 ',
' -start_number ',
' 0 ',
outformat
' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
' 0.1 ', ' -start_number ', ' 0 ', outformat
]
else:
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
......@@ -134,20 +123,21 @@ class EDVRPredictor:
self.input = input
self.output = os.path.join(output, 'EDVR')
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
place = fluid.CUDAPlace(
0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
self.exe = fluid.Executor(place)
if weight_path is None:
weight_path = get_path_from_url(EDVR_weight_url, cur_path)
print(weight_path)
model_filename = 'EDVR_model.pdmodel'
params_filename = 'EDVR_params.pdparams'
out = fluid.io.load_inference_model(dirname=weight_path,
model_filename=model_filename,
params_filename=params_filename,
out = fluid.io.load_inference_model(dirname=weight_path,
model_filename=model_filename,
params_filename=params_filename,
executor=self.exe)
self.infer_prog, self.feed_list, self.fetch_list = out
......@@ -176,16 +166,19 @@ class EDVRPredictor:
cur_time = time.time()
for infer_iter, data in enumerate(dataset):
data_feed_in = [data[0]]
infer_outs = self.exe.run(self.infer_prog,
fetch_list=self.fetch_list,
feed={self.feed_list[0]:np.array(data_feed_in)})
infer_outs = self.exe.run(
self.infer_prog,
fetch_list=self.fetch_list,
feed={self.feed_list[0]: np.array(data_feed_in)})
infer_result_list = [item for item in infer_outs]
frame_path = data[1]
img_i = get_img(infer_result_list[0])
save_img(img_i, os.path.join(pred_frame_path, os.path.basename(frame_path)))
save_img(
img_i,
os.path.join(pred_frame_path, os.path.basename(frame_path)))
prev_time = cur_time
cur_time = time.time()
......@@ -194,13 +187,15 @@ class EDVRPredictor:
print('Processed {} samples'.format(infer_iter + 1))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(self.output, '{}_edvr_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps)))
vid_out_path = os.path.join(self.output,
'{}_edvr_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
str(int(fps)))
return frame_pattern_combined, vid_out_path
if __name__ == "__main__":
args = parse_args()
predictor = EDVRPredictor(args.input, args.output, args.weight_path)
predictor.run()
......@@ -60,11 +60,12 @@ dataset:
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
......@@ -72,4 +73,3 @@ log_config:
snapshot_config:
interval: 5
......@@ -59,11 +59,12 @@ dataset:
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
......@@ -71,4 +72,3 @@ log_config:
snapshot_config:
interval: 5
......@@ -25,7 +25,7 @@ dataset:
train:
name: PairedDataset
dataroot: data/cityscapes
num_workers: 0
num_workers: 4
phase: train
max_dataset_size: inf
direction: BtoA
......@@ -57,11 +57,12 @@ dataset:
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
......@@ -69,4 +70,3 @@ log_config:
snapshot_config:
interval: 5
......@@ -56,11 +56,12 @@ dataset:
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0004
start_epoch: 100
decay_epochs: 100
lr_scheduler:
name: linear
learning_rate: 0.0004
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
......@@ -68,4 +69,3 @@ log_config:
snapshot_config:
interval: 5
......@@ -56,11 +56,12 @@ dataset:
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 100
......
......@@ -6,7 +6,7 @@ from paddle.io import Dataset
from PIL import Image
import cv2
import paddle.incubate.hapi.vision.transforms as transforms
import paddle.vision.transforms as transforms
from .transforms import transforms as T
from abc import ABC, abstractmethod
......@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
class BaseDataset(Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets.
"""
def __init__(self, cfg):
"""Initialize the class; save the options in the class
......@@ -60,8 +59,11 @@ def get_params(cfg, size):
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, convert=True):
def get_transform(cfg,
params=None,
grayscale=False,
method=cv2.INTER_CUBIC,
convert=True):
transform_list = []
if grayscale:
print('grayscale not support for now!!!')
......@@ -89,8 +91,10 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con
transform_list.append(transforms.RandomHorizontalFlip())
elif params['flip']:
transform_list.append(transforms.RandomHorizontalFlip(1.0))
if convert:
transform_list += [transforms.Permute(to_rgb=True)]
transform_list += [transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))]
transform_list += [
transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
]
return transforms.Compose(transform_list)
......@@ -3,12 +3,11 @@ import paddle
import numbers
import numpy as np
from multiprocessing import Manager
from paddle import ParallelEnv
from paddle.distributed import ParallelEnv
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from paddle.io import DistributedBatchSampler
from ..utils.registry import Registry
DATASETS = Registry("DATASETS")
......@@ -21,7 +20,7 @@ class DictDataset(paddle.io.Dataset):
single_item = dataset[0]
self.keys = single_item.keys()
for k, v in single_item.items():
if not isinstance(v, (numbers.Number, np.ndarray)):
setattr(self, k, Manager().dict())
......@@ -32,9 +31,9 @@ class DictDataset(paddle.io.Dataset):
def __getitem__(self, index):
ori_map = self.dataset[index]
tmp_list = []
for k, v in ori_map.items():
if isinstance(v, (numbers.Number, np.ndarray)):
tmp_list.append(v)
......@@ -60,17 +59,15 @@ class DictDataLoader():
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)
sampler = DistributedBatchSampler(
self.dataset,
batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False)
sampler = DistributedBatchSampler(self.dataset,
batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False)
self.dataloader = paddle.io.DataLoader(
self.dataset,
batch_sampler=sampler,
places=place,
num_workers=num_workers)
self.dataloader = paddle.io.DataLoader(self.dataset,
batch_sampler=sampler,
places=place,
num_workers=num_workers)
self.batch_size = batch_size
......@@ -83,7 +80,9 @@ class DictDataLoader():
j = 0
for k in self.dataset.keys:
if k in self.dataset.tensor_keys_set:
return_dict[k] = data[j] if isinstance(data, (list, tuple)) else data
return_dict[k] = data[j] if isinstance(data,
(list,
tuple)) else data
j += 1
else:
return_dict[k] = self.get_items_by_indexs(k, data[-1])
......@@ -104,13 +103,12 @@ class DictDataLoader():
return current_items
def build_dataloader(cfg, is_train=True):
dataset = DATASETS.get(cfg.name)(cfg)
batch_size = cfg.get('batch_size', 1)
num_workers = cfg.get('num_workers', 0)
dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)
return dataloader
\ No newline at end of file
return dataloader
......@@ -4,7 +4,7 @@ import time
import logging
import paddle
from paddle import ParallelEnv, DataParallel
from paddle.distributed import ParallelEnv
from ..datasets.builder import build_dataloader
from ..models.builder import build_model
......@@ -17,10 +17,11 @@ class Trainer:
# build train dataloader
self.train_dataloader = build_dataloader(cfg.dataset.train)
if 'lr_scheduler' in cfg.optimizer:
cfg.optimizer.lr_scheduler.step_per_epoch = len(self.train_dataloader)
cfg.optimizer.lr_scheduler.step_per_epoch = len(
self.train_dataloader)
# build model
self.model = build_model(cfg)
# multiple gpus prepare
......@@ -44,16 +45,17 @@ class Trainer:
# time count
self.time_count = {}
def distributed_data_parallel(self):
strategy = paddle.prepare_context()
for name in self.model.model_names:
if isinstance(name, str):
net = getattr(self.model, 'net' + name)
setattr(self.model, 'net' + name, DataParallel(net, strategy))
setattr(self.model, 'net' + name,
paddle.DataParallel(net, strategy))
def train(self):
for epoch in range(self.start_epoch, self.epochs):
self.current_epoch = epoch
start_time = step_start_time = time.time()
......@@ -64,24 +66,27 @@ class Trainer:
# data input should be dict
self.model.set_input(data)
self.model.optimize_parameters()
self.data_time = data_time - step_start_time
self.step_time = time.time() - step_start_time
if i % self.log_interval == 0:
self.print_log()
if i % self.visual_interval == 0:
self.visual('visual_train')
step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() - start_time))
self.logger.info('train one epoch time: {}'.format(time.time() -
start_time))
self.model.lr_scheduler.step()
if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1)
self.save(epoch)
def test(self):
if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test, is_train=False)
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False)
# data[0]: img, data[1]: img path index
# test batch size must be 1
......@@ -103,14 +108,15 @@ class Trainer:
visual_results.update({name: img_tensor[j]})
self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' % (i, len(self.test_dataloader)))
self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader)))
def print_log(self):
losses = self.model.get_current_losses()
message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id)
message += '%s: %.6f ' % ('lr', self.current_learning_rate)
for k, v in losses.items():
......@@ -143,13 +149,14 @@ class Trainer:
makedirs(os.path.join(self.output_dir, results_dir))
for label, image in visual_results.items():
image_numpy = tensor2img(image)
img_path = os.path.join(self.output_dir, results_dir, msg + '%s.png' % (label))
img_path = os.path.join(self.output_dir, results_dir,
msg + '%s.png' % (label))
save_image(image_numpy, img_path)
def save(self, epoch, name='checkpoint', keep=1):
if self.local_rank != 0:
return
assert name in ['checkpoint', 'weight']
state_dicts = {}
......@@ -175,8 +182,8 @@ class Trainer:
if keep > 0:
try:
checkpoint_name_to_be_removed = os.path.join(self.output_dir,
'epoch_%s_%s.pkl' % (epoch - keep, name))
checkpoint_name_to_be_removed = os.path.join(
self.output_dir, 'epoch_%s_%s.pkl' % (epoch - keep, name))
if os.path.exists(checkpoint_name_to_be_removed):
os.remove(checkpoint_name_to_be_removed)
......@@ -187,7 +194,7 @@ class Trainer:
state_dicts = load(checkpoint_path)
if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1
for name in self.model.model_names:
if isinstance(name, str):
net = getattr(self.model, 'net' + name)
......@@ -200,9 +207,8 @@ class Trainer:
def load(self, weight_path):
state_dicts = load(weight_path)
for name in self.model.model_names:
if isinstance(name, str):
net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name])
\ No newline at end of file
......@@ -5,6 +5,7 @@ import numpy as np
from collections import OrderedDict
from abc import ABC, abstractmethod
from ..solver.lr_scheduler import build_lr_scheduler
class BaseModel(ABC):
......@@ -16,7 +17,6 @@ class BaseModel(ABC):
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the BaseModel class.
......@@ -33,8 +33,10 @@ class BaseModel(ABC):
"""
self.opt = opt
self.isTrain = opt.isTrain
self.save_dir = os.path.join(opt.output_dir, opt.model.name) # save all the checkpoints to save_dir
self.save_dir = os.path.join(
opt.output_dir,
opt.model.name) # save all the checkpoints to save_dir
self.loss_names = []
self.model_names = []
self.visual_names = []
......@@ -75,6 +77,8 @@ class BaseModel(ABC):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def build_lr_scheduler(self):
self.lr_scheduler = build_lr_scheduler(self.opt.lr_scheduler)
def eval(self):
"""Make models eval mode during test time"""
......@@ -114,10 +118,11 @@ class BaseModel(ABC):
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
errors_ret[name] = float(
getattr(self, 'loss_' + name)
) # float(...) works for both scalar tensor and float number
return errors_ret
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
......
import paddle
from paddle import ParallelEnv
from paddle.distributed import ParallelEnv
from .base_model import BaseModel
from .builder import MODELS
......@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel):
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
"""
def __init__(self, opt):
"""Initialize the CycleGAN class.
......@@ -32,12 +31,14 @@ class CycleGANModel(BaseModel):
"""
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
self.loss_names = [
'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B']
# if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
# if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
if self.isTrain and self.opt.lambda_identity > 0.0:
visual_names_A.append('idt_B')
visual_names_B.append('idt_A')
......@@ -62,18 +63,28 @@ class CycleGANModel(BaseModel):
if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert(opt.dataset.train.input_nc == opt.dataset.train.output_nc)
assert (
opt.dataset.train.input_nc == opt.dataset.train.output_nc)
# create image buffer to store previously generated images
self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
# create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)
# define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionCycle = paddle.nn.L1Loss()
self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss()
self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG_A.parameters() + self.netG_B.parameters())
self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD_A.parameters() + self.netD_B.parameters())
self.build_lr_scheduler()
self.optimizer_G = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netG_A.parameters() +
self.netG_B.parameters())
self.optimizer_D = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netD_A.parameters() +
self.netD_B.parameters())
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
......@@ -90,7 +101,7 @@ class CycleGANModel(BaseModel):
"""
mode = 'train' if self.isTrain else 'test'
AtoB = self.opt.dataset[mode].direction == 'AtoB'
if AtoB:
if 'A' in input:
self.real_A = paddle.to_tensor(input['A'])
......@@ -107,17 +118,15 @@ class CycleGANModel(BaseModel):
elif 'B_paths' in input:
self.image_paths = input['B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
if hasattr(self, 'real_A'):
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
if hasattr(self, 'real_B'):
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
......@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel):
if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_A) * lambda_A * lambda_idt
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
......@@ -179,12 +190,14 @@ class CycleGANModel(BaseModel):
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
if ParallelEnv().nranks > 1:
self.loss_G = self.netG_A.scale_loss(self.loss_G)
self.loss_G.backward()
......@@ -216,6 +229,5 @@ class CycleGANModel(BaseModel):
self.backward_D_A()
# calculate graidents for D_B
self.backward_D_B()
# update D_A and D_B's weights
# update D_A and D_B's weights
self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)
import paddle
from paddle import ParallelEnv
from paddle.distributed import ParallelEnv
from .base_model import BaseModel
from .builder import MODELS
......@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel):
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
"""
def __init__(self, opt):
"""Initialize the pix2pix class.
......@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel):
if self.isTrain:
self.netD = build_discriminator(opt.model.discriminator)
if self.isTrain:
# define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionL1 = paddle.nn.L1Loss()
# build optimizers
self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters())
self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters())
self.build_lr_scheduler()
self.optimizer_G = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netG.parameters())
self.optimizer_D = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netD.parameters())
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
......@@ -75,7 +80,6 @@ class Pix2PixModel(BaseModel):
self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
......@@ -84,7 +88,7 @@ class Pix2PixModel(BaseModel):
def forward_test(self, input):
input = paddle.imperative.to_variable(input)
return self.netG(input)
def backward_D(self):
"""Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake_B
......@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel):
pred_fake = self.netD(fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B
self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
self.loss_G_L1 = self.criterionL1(self.fake_B,
self.real_B) * self.opt.lambda_L1
# combine loss and calculate gradients
self.loss_G = self.loss_G_GAN + self.loss_G_L1
......@@ -129,12 +134,12 @@ class Pix2PixModel(BaseModel):
# update D
self.set_requires_grad(self.netD, True)
self.optimizer_D.clear_gradients()
self.optimizer_D.clear_gradients()
self.backward_D()
self.optimizer_D.minimize(self.loss_D)
self.optimizer_D.minimize(self.loss_D)
# update G
self.set_requires_grad(self.netD, False)
self.set_requires_grad(self.netD, False)
self.optimizer_G.clear_gradients()
self.backward_G()
self.optimizer_G.minimize(self.loss_G)
......@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg):
# TODO: add more learning rate scheduler
if name == 'linear':
return LinearDecay(**cfg)
def lambda_rule(epoch):
lr_l = 1.0 - max(
0, epoch + 1 - cfg.start_epoch) / float(cfg.decay_epochs + 1)
return lr_l
scheduler = paddle.optimizer.lr_scheduler.LambdaLR(
cfg.learning_rate, lr_lambda=lambda_rule)
return scheduler
else:
raise NotImplementedError
class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay):
def __init__(self, learning_rate, step_per_epoch, start_epoch, decay_epochs):
# paddle.optimizer.lr_scheduler
class LinearDecay(paddle.optimizer.lr_scheduler._LRScheduler):
def __init__(self, learning_rate, step_per_epoch, start_epoch,
decay_epochs):
super(LinearDecay, self).__init__()
self.learning_rate = learning_rate
self.start_epoch = start_epoch
......@@ -21,5 +31,6 @@ class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay
def step(self):
cur_epoch = int(self.step_num // self.step_per_epoch)
decay_rate = 1.0 - max(0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
return self.create_lr_var(decay_rate * self.learning_rate)
\ No newline at end of file
decay_rate = 1.0 - max(
0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
return self.create_lr_var(decay_rate * self.learning_rate)
......@@ -4,13 +4,11 @@ import paddle
from .lr_scheduler import build_lr_scheduler
def build_optimizer(cfg, parameter_list=None):
def build_optimizer(cfg, lr_scheduler, parameter_list=None):
cfg_copy = copy.deepcopy(cfg)
lr_scheduler_cfg = cfg_copy.pop('lr_scheduler', None)
lr_scheduler = build_lr_scheduler(lr_scheduler_cfg)
opt_name = cfg_copy.pop('name')
return getattr(paddle.optimizer, opt_name)(lr_scheduler, parameters=parameter_list, **cfg_copy)
return getattr(paddle.optimizer, opt_name)(lr_scheduler,
parameters=parameter_list,
**cfg_copy)
......@@ -2,7 +2,7 @@ import logging
import os
import sys
from paddle import ParallelEnv
from paddle.distributed import ParallelEnv
def setup_logger(output=None, name="ppgan"):
......@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"):
logger.propagate = False
plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
)
"[%(asctime)s] %(name)s %(levelname)s: %(message)s",
datefmt="%m/%d %H:%M:%S")
# stdout logging: master only
local_rank = ParallelEnv().local_rank
if local_rank == 0:
......@@ -52,4 +52,4 @@ def setup_logger(output=None, name="ppgan"):
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
\ No newline at end of file
return logger
......@@ -2,7 +2,7 @@ import os
import time
import paddle
from paddle import ParallelEnv
from paddle.distributed import ParallelEnv
from .logger import setup_logger
......@@ -12,7 +12,8 @@ def setup(args, cfg):
cfg.isTrain = False
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
cfg.output_dir = os.path.join(cfg.output_dir, str(cfg.model.name) + cfg.timestamp)
cfg.output_dir = os.path.join(cfg.output_dir,
str(cfg.model.name) + cfg.timestamp)
logger = setup_logger(cfg.output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册