未验证 提交 8a4848dc 编写于 作者: L LielinJiang 提交者: GitHub

Merge pull request #23 from LielinJiang/sr

Move some model to ppgan, Add sr model
...@@ -13,8 +13,8 @@ import cv2 ...@@ -13,8 +13,8 @@ import cv2
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
from ppgan.utils.video import video2frames, frames2video
import networks
from util import * from util import *
from my_args import parser from my_args import parser
...@@ -129,7 +129,7 @@ class VideoFrameInterp(object): ...@@ -129,7 +129,7 @@ class VideoFrameInterp(object):
r2 = str(int(fps) * times_interp) r2 = str(int(fps) * times_interp)
print("New fps (frame rate): ", r2) print("New fps (frame rate): ", r2)
out_path = dump_frames_ffmpeg(vid, frame_path_input) out_path = video2frames(vid, frame_path_input)
vidname = vid.split('/')[-1].split('.')[0] vidname = vid.split('/')[-1].split('.')[0]
...@@ -266,7 +266,7 @@ class VideoFrameInterp(object): ...@@ -266,7 +266,7 @@ class VideoFrameInterp(object):
vidname + '.mp4') vidname + '.mp4')
if os.path.exists(video_pattern_output): if os.path.exists(video_pattern_output):
os.remove(video_pattern_output) os.remove(video_pattern_output)
frames_to_video_ffmpeg(frame_pattern_combined, video_pattern_output, frames2video(frame_pattern_combined, video_pattern_output,
r2) r2)
return frame_pattern_combined, video_pattern_output return frame_pattern_combined, video_pattern_output
......
...@@ -21,66 +21,6 @@ class AverageMeter(object): ...@@ -21,66 +21,6 @@ class AverageMeter(object):
self.avg = self.sum / self.count self.avg = self.sum / self.count
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
    """Extract the frames of a video into numbered PNG files using ffmpeg.

    Frames are written as ``%08d.png`` (numbering starts at 0) into
    ``<outpath>/<video name>``, which is created when it does not exist.

    Args:
        vid_path: Path of the input video.
        outpath: Directory under which the per-video frame folder is created.
        r: Output frame rate, passed to ffmpeg ``-r``.
        ss: Start offset, passed to ffmpeg ``-ss``.
        t: Duration, passed to ffmpeg ``-t``.
            ``r``, ``ss`` and ``t`` are only applied when all three are given.

    Returns:
        The directory the frames were written to. A failed ffmpeg run is
        reported on stdout and does not raise.
    """
    # Local imports: the block must not rely on module-level imports it
    # did not use before.
    import subprocess
    import sys

    vid_name = vid_path.split('/')[-1].split('.')[0]
    out_full_path = os.path.join(outpath, vid_name)
    os.makedirs(out_full_path, exist_ok=True)

    outformat = out_full_path + '/%08d.png'

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and cannot be abused for command injection.
    cmd = ['ffmpeg', '-y', '-loglevel', 'error']
    if ss is not None and t is not None and r is not None:
        cmd += [
            '-ss', str(ss), '-t', str(t), '-i', vid_path, '-r', str(r),
            '-qscale:v', '0.1', '-start_number', '0', outformat
        ]
    else:
        cmd += ['-i', vid_path, '-start_number', '0', outformat]

    try:
        failed = subprocess.call(cmd) != 0
    except OSError:
        # ffmpeg missing or not executable: report it like any other failure.
        failed = True
    if failed:
        print('ffmpeg process video: {} error'.format(vid_name))

    sys.stdout.flush()
    return out_full_path
def frames_to_video_ffmpeg(framepath, videopath, r):
    """Encode an image sequence into an H.264 video using ffmpeg.

    Args:
        framepath: ffmpeg input pattern of the frames (e.g. ``dir/%08d.png``).
        videopath: Path of the video file to write (overwritten if present).
        r: Frame rate, passed to ffmpeg ``-r``.

    Returns:
        None. A failed ffmpeg run is reported on stdout and does not raise.
    """
    # Local imports: the block must not rely on module-level imports it
    # did not use before.
    import subprocess
    import sys

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and cannot be abused for command injection.
    cmd = [
        'ffmpeg', '-y', '-loglevel', 'error', '-r', str(r), '-f', 'image2',
        '-i', framepath, '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', '-crf',
        '16', videopath
    ]

    try:
        failed = subprocess.call(cmd) != 0
    except OSError:
        # ffmpeg missing or not executable: report it like any other failure.
        failed = True
    if failed:
        print('ffmpeg process video: {} error'.format(videopath))

    sys.stdout.flush()
def combine_frames(input, interpolated, combined, num_frames): def combine_frames(input, interpolated, combined, num_frames):
frames1 = sorted(glob.glob(os.path.join(input, '*.png'))) frames1 = sorted(glob.glob(os.path.join(input, '*.png')))
frames2 = sorted(glob.glob(os.path.join(interpolated, '*.png'))) frames2 = sorted(glob.glob(os.path.join(interpolated, '*.png')))
......
...@@ -14,8 +14,9 @@ import pickle ...@@ -14,8 +14,9 @@ import pickle
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from paddle import fluid from paddle import fluid
from model import build_model
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators.deoldify import build_model
parser = argparse.ArgumentParser(description='DeOldify') parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video') parser.add_argument('--input', type=str, default='none', help='Input video')
...@@ -29,23 +30,7 @@ parser.add_argument('--weight_path', ...@@ -29,23 +30,7 @@ parser.add_argument('--weight_path',
default=None, default=None,
help='Path to the reference image directory') help='Path to the reference image directory')
DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' DEOLDIFY_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r):
    """Encode an image sequence into an H.264 video using ffmpeg.

    Args:
        framepath: ffmpeg input pattern of the frames (e.g. ``dir/%08d.png``).
        videopath: Path of the video file to write (overwritten if present).
        r: Frame rate, passed to ffmpeg ``-r``.

    Returns:
        None. A failed ffmpeg run is reported on stdout and does not raise.
    """
    # Local imports: the block must not rely on module-level imports it
    # did not use before.
    import subprocess
    import sys

    # Passing arguments as a list avoids the shell, so unusual characters
    # in either path reach ffmpeg unchanged.
    cmd = [
        'ffmpeg', '-y', '-loglevel', 'error', '-r', str(r), '-f', 'image2',
        '-i', framepath, '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', '-crf',
        '16', videopath
    ]

    try:
        failed = subprocess.call(cmd) != 0
    except OSError:
        # ffmpeg missing or not executable: report it like any other failure.
        failed = True
    if failed:
        print('ffmpeg process video: {} error'.format(videopath))

    sys.stdout.flush()
class DeOldifyPredictor(): class DeOldifyPredictor():
...@@ -60,7 +45,7 @@ class DeOldifyPredictor(): ...@@ -60,7 +45,7 @@ class DeOldifyPredictor():
self.render_factor = render_factor self.render_factor = render_factor
self.model = build_model() self.model = build_model()
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(DeOldify_weight_url, cur_path) weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL, cur_path)
state_dict, _ = paddle.load(weight_path) state_dict, _ = paddle.load(weight_path)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
...@@ -127,7 +112,7 @@ class DeOldifyPredictor(): ...@@ -127,7 +112,7 @@ class DeOldifyPredictor():
cap = cv2.VideoCapture(vid) cap = cv2.VideoCapture(vid)
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
out_path = dump_frames_ffmpeg(vid, output_path) out_path = video2frames(vid, output_path)
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
...@@ -141,42 +126,11 @@ class DeOldifyPredictor(): ...@@ -141,42 +126,11 @@ class DeOldifyPredictor():
vid_out_path = os.path.join(output_path, vid_out_path = os.path.join(output_path,
'{}_deoldify_out.mp4'.format(base_name)) '{}_deoldify_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
    """Extract the frames of a video into numbered PNG files using ffmpeg.

    Frames are written as ``%08d.png`` (numbering starts at 0) into
    ``<outpath>/frames_input``, which is created when it does not exist.

    Args:
        vid_path: Path of the input video.
        outpath: Directory under which ``frames_input`` is created.
        r: Output frame rate, passed to ffmpeg ``-r``.
        ss: Start offset, passed to ffmpeg ``-ss``.
        t: Duration, passed to ffmpeg ``-t``.
            ``r``, ``ss`` and ``t`` are only applied when all three are given.

    Returns:
        The directory the frames were written to. A failed ffmpeg run is
        reported on stdout and does not raise.
    """
    # Local imports: the block must not rely on module-level imports it
    # did not use before.
    import subprocess
    import sys

    # Only used to name the video in the failure message.
    vid_name = vid_path.split('/')[-1].split('.')[0]
    out_full_path = os.path.join(outpath, 'frames_input')
    os.makedirs(out_full_path, exist_ok=True)

    outformat = out_full_path + '/%08d.png'

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and cannot be abused for command injection.
    cmd = ['ffmpeg', '-y', '-loglevel', 'error']
    if ss is not None and t is not None and r is not None:
        cmd += [
            '-ss', str(ss), '-t', str(t), '-i', vid_path, '-r', str(r),
            '-qscale:v', '0.1', '-start_number', '0', outformat
        ]
    else:
        cmd += ['-i', vid_path, '-start_number', '0', outformat]

    try:
        failed = subprocess.call(cmd) != 0
    except OSError:
        # ffmpeg missing or not executable: report it like any other failure.
        failed = True
    if failed:
        print('ffmpeg process video: {} error'.format(vid_name))

    sys.stdout.flush()
    return out_full_path
if __name__ == '__main__': if __name__ == '__main__':
paddle.disable_static() paddle.disable_static()
args = parser.parse_args() args = parser.parse_args()
......
import numpy as np
from paddle import fluid
from paddle.fluid import dygraph
from paddle.fluid import layers as F
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_variable_and_dtype
import paddle
import paddle.nn as nn
class _SpectralNorm(nn.SpectralNorm):
    """``nn.SpectralNorm`` whose forward pass skips power iteration outside training.

    The constructor forwards its arguments unchanged to the parent class;
    only ``forward`` is overridden.
    """

    def __init__(self,
                 weight_shape,
                 dim=0,
                 power_iters=1,
                 eps=1e-12,
                 dtype='float32'):
        super(_SpectralNorm, self).__init__(weight_shape, dim, power_iters,
                                            eps, dtype)

    def forward(self, weight):
        """Append a ``spectral_norm`` op for ``weight`` and return its output."""
        check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
                                 'SpectralNorm')

        # Use the configured number of power iterations only while training;
        # otherwise request zero iterations.
        iters = self._power_iters if self.training else 0

        op_inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
        op_attrs = {
            "dim": self._dim,
            "power_iters": iters,
            "eps": self._eps,
        }
        normalized = self._helper.create_variable_for_type_inference(
            self._dtype)
        self._helper.append_op(
            type="spectral_norm",
            inputs=op_inputs,
            outputs={"Out": normalized},
            attrs=op_attrs)
        return normalized
class Spectralnorm(nn.Layer):
    """Wrap a layer so its ``weight`` is spectrally normalized on every call.

    The wrapped layer's ``weight`` parameter is removed from the layer and
    kept here as ``weight_orig``; each forward pass normalizes it and
    assigns the result back to ``layer.weight`` before calling the layer.
    """

    def __init__(self,
                 layer,
                 dim=0,
                 power_iters=1,
                 eps=1e-12,
                 dtype='float32'):
        super(Spectralnorm, self).__init__()
        # Built first: it needs the weight shape before the parameter is
        # detached from the wrapped layer below.
        self.spectral_norm = _SpectralNorm(layer.weight.shape, dim,
                                           power_iters, eps, dtype)
        self.dim = dim
        self.power_iters = power_iters
        self.eps = eps
        self.layer = layer

        # Detach the raw weight from the wrapped layer and own a copy here.
        raw_weight = layer._parameters.pop('weight')
        self.weight_orig = self.create_parameter(
            raw_weight.shape, dtype=raw_weight.dtype)
        self.weight_orig.set_value(raw_weight)

    def forward(self, x):
        """Normalize the stored weight, install it on the layer and apply the layer."""
        self.layer.weight = self.spectral_norm(self.weight_orig)
        return self.layer(x)
...@@ -14,10 +14,10 @@ from tqdm import tqdm ...@@ -14,10 +14,10 @@ from tqdm import tqdm
import argparse import argparse
import subprocess import subprocess
import utils import utils
from remasternet import NetworkR, NetworkC from ppgan.models.generators.remaster import NetworkR, NetworkC
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams' DEEPREMASTER_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
parser = argparse.ArgumentParser(description='Remastering') parser = argparse.ArgumentParser(description='Remastering')
parser.add_argument('--input', type=str, default=None, help='Input video') parser.add_argument('--input', type=str, default=None, help='Input video')
...@@ -51,7 +51,7 @@ class DeepReasterPredictor: ...@@ -51,7 +51,7 @@ class DeepReasterPredictor:
self.mindim = mindim self.mindim = mindim
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path) weight_path = get_path_from_url(DEEPREMASTER_WEIGHT_URL, cur_path)
state_dict, _ = paddle.load(weight_path) state_dict, _ = paddle.load(weight_path)
......
...@@ -30,8 +30,9 @@ import cv2 ...@@ -30,8 +30,9 @@ import cv2
from tqdm import tqdm from tqdm import tqdm
from data import EDVRDataset from data import EDVRDataset
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames
EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar' EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
def parse_args(): def parse_args():
...@@ -71,52 +72,6 @@ def save_img(img, framename): ...@@ -71,52 +72,6 @@ def save_img(img, framename):
cv2.imwrite(framename, img) cv2.imwrite(framename, img)
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
    """Extract the frames of a video into numbered PNG files using ffmpeg.

    Frames are written as ``%08d.png`` (numbering starts at 0) into
    ``<outpath>/frames_input``, which is created when it does not exist.

    Args:
        vid_path: Path of the input video.
        outpath: Directory under which ``frames_input`` is created.
        r: Output frame rate, passed to ffmpeg ``-r``.
        ss: Start offset, passed to ffmpeg ``-ss``.
        t: Duration, passed to ffmpeg ``-t``.
            ``r``, ``ss`` and ``t`` are only applied when all three are given.

    Returns:
        The directory the frames were written to. A failed ffmpeg run is
        reported on stdout and does not raise.
    """
    # Local imports: the block must not rely on module-level imports it
    # did not use before.
    import subprocess
    import sys

    # Only used to name the video in the failure message.
    vid_name = vid_path.split('/')[-1].split('.')[0]
    out_full_path = os.path.join(outpath, 'frames_input')
    os.makedirs(out_full_path, exist_ok=True)

    outformat = out_full_path + '/%08d.png'

    # Passing arguments as a list avoids the shell, so unusual characters
    # in either path reach ffmpeg unchanged.
    cmd = ['ffmpeg', '-y', '-loglevel', 'error']
    if ss is not None and t is not None and r is not None:
        cmd += [
            '-ss', str(ss), '-t', str(t), '-i', vid_path, '-r', str(r),
            '-qscale:v', '0.1', '-start_number', '0', outformat
        ]
    else:
        cmd += ['-i', vid_path, '-start_number', '0', outformat]

    try:
        failed = subprocess.call(cmd) != 0
    except OSError:
        # ffmpeg missing or not executable: report it like any other failure.
        failed = True
    if failed:
        print('ffmpeg process video: {} error'.format(vid_name))

    sys.stdout.flush()
    return out_full_path
def frames_to_video_ffmpeg(framepath, videopath, r):
    """Encode an image sequence into an H.264 video using ffmpeg.

    Args:
        framepath: ffmpeg input pattern of the frames (e.g. ``dir/%08d.png``).
        videopath: Path of the video file to write (overwritten if present).
        r: Frame rate, passed to ffmpeg ``-r``.

    Returns:
        None. A failed ffmpeg run is reported on stdout and does not raise.
    """
    # Local imports: the block must not rely on module-level imports it
    # did not use before.
    import subprocess
    import sys

    # Passing arguments as a list avoids the shell, so unusual characters
    # in either path reach ffmpeg unchanged.
    cmd = [
        'ffmpeg', '-y', '-loglevel', 'error', '-r', str(r), '-f', 'image2',
        '-i', framepath, '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', '-crf',
        '16', videopath
    ]

    try:
        failed = subprocess.call(cmd) != 0
    except OSError:
        # ffmpeg missing or not executable: report it like any other failure.
        failed = True
    if failed:
        print('ffmpeg process video: {} error'.format(videopath))

    sys.stdout.flush()
class EDVRPredictor: class EDVRPredictor:
def __init__(self, input, output, weight_path=None): def __init__(self, input, output, weight_path=None):
self.input = input self.input = input
...@@ -127,9 +82,7 @@ class EDVRPredictor: ...@@ -127,9 +82,7 @@ class EDVRPredictor:
self.exe = fluid.Executor(place) self.exe = fluid.Executor(place)
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(EDVR_weight_url, cur_path) weight_path = get_path_from_url(EDVR_WEIGHT_URL, cur_path)
print(weight_path)
model_filename = 'EDVR_model.pdmodel' model_filename = 'EDVR_model.pdmodel'
params_filename = 'EDVR_params.pdparams' params_filename = 'EDVR_params.pdparams'
...@@ -155,7 +108,7 @@ class EDVRPredictor: ...@@ -155,7 +108,7 @@ class EDVRPredictor:
cap = cv2.VideoCapture(vid) cap = cv2.VideoCapture(vid)
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
out_path = dump_frames_ffmpeg(vid, output_path) out_path = video2frames(vid, output_path)
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
...@@ -188,8 +141,7 @@ class EDVRPredictor: ...@@ -188,8 +141,7 @@ class EDVRPredictor:
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(self.output, vid_out_path = os.path.join(self.output,
'{}_edvr_out.mp4'.format(base_name)) '{}_edvr_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
......
...@@ -13,7 +13,9 @@ import pickle ...@@ -13,7 +13,9 @@ import pickle
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from sr_model import RRDBNet
from ppgan.models.generators import RRDBNet
from ppgan.utils.video import frames2video, video2frames
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
parser = argparse.ArgumentParser(description='RealSR') parser = argparse.ArgumentParser(description='RealSR')
...@@ -24,23 +26,7 @@ parser.add_argument('--weight_path', ...@@ -24,23 +26,7 @@ parser.add_argument('--weight_path',
default=None, default=None,
help='Path to the reference image directory') help='Path to the reference image directory')
RealSR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams' REALSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r):
    """Encode an image sequence into an H.264 video using ffmpeg.

    Args:
        framepath: ffmpeg input pattern of the frames (e.g. ``dir/%08d.png``).
        videopath: Path of the video file to write (overwritten if present).
        r: Frame rate, passed to ffmpeg ``-r``.

    Returns:
        None. A failed ffmpeg run is reported on stdout and does not raise.
    """
    # Local imports: the block must not rely on module-level imports it
    # did not use before.
    import subprocess
    import sys

    # Passing arguments as a list avoids the shell, so unusual characters
    # in either path reach ffmpeg unchanged.
    cmd = [
        'ffmpeg', '-y', '-loglevel', 'error', '-r', str(r), '-f', 'image2',
        '-i', framepath, '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', '-crf',
        '16', videopath
    ]

    try:
        failed = subprocess.call(cmd) != 0
    except OSError:
        # ffmpeg missing or not executable: report it like any other failure.
        failed = True
    if failed:
        print('ffmpeg process video: {} error'.format(videopath))

    sys.stdout.flush()
class RealSRPredictor(): class RealSRPredictor():
...@@ -49,7 +35,7 @@ class RealSRPredictor(): ...@@ -49,7 +35,7 @@ class RealSRPredictor():
self.output = os.path.join(output, 'RealSR') self.output = os.path.join(output, 'RealSR')
self.model = RRDBNet(3, 3, 64, 23) self.model = RRDBNet(3, 3, 64, 23)
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(RealSR_weight_url, cur_path) weight_path = get_path_from_url(REALSR_WEIGHT_URL, cur_path)
state_dict, _ = paddle.load(weight_path) state_dict, _ = paddle.load(weight_path)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
...@@ -88,7 +74,7 @@ class RealSRPredictor(): ...@@ -88,7 +74,7 @@ class RealSRPredictor():
cap = cv2.VideoCapture(vid) cap = cv2.VideoCapture(vid)
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
out_path = dump_frames_ffmpeg(vid, output_path) out_path = video2frames(vid, output_path)
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
...@@ -102,42 +88,11 @@ class RealSRPredictor(): ...@@ -102,42 +88,11 @@ class RealSRPredictor():
vid_out_path = os.path.join(output_path, vid_out_path = os.path.join(output_path,
'{}_realsr_out.mp4'.format(base_name)) '{}_realsr_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
    """Extract the frames of a video into numbered PNG files using ffmpeg.

    Frames are written as ``%08d.png`` (numbering starts at 0) into
    ``<outpath>/frames_input``, which is created when it does not exist.

    Args:
        vid_path: Path of the input video.
        outpath: Directory under which ``frames_input`` is created.
        r: Output frame rate, passed to ffmpeg ``-r``.
        ss: Start offset, passed to ffmpeg ``-ss``.
        t: Duration, passed to ffmpeg ``-t``.
            ``r``, ``ss`` and ``t`` are only applied when all three are given.

    Returns:
        The directory the frames were written to. A failed ffmpeg run is
        reported on stdout and does not raise.
    """
    # Local imports: the block must not rely on module-level imports it
    # did not use before.
    import subprocess
    import sys

    # Only used to name the video in the failure message.
    vid_name = vid_path.split('/')[-1].split('.')[0]
    out_full_path = os.path.join(outpath, 'frames_input')
    os.makedirs(out_full_path, exist_ok=True)

    outformat = out_full_path + '/%08d.png'

    # Passing arguments as a list avoids the shell, so unusual characters
    # in either path reach ffmpeg unchanged.
    cmd = ['ffmpeg', '-y', '-loglevel', 'error']
    if ss is not None and t is not None and r is not None:
        cmd += [
            '-ss', str(ss), '-t', str(t), '-i', vid_path, '-r', str(r),
            '-qscale:v', '0.1', '-start_number', '0', outformat
        ]
    else:
        cmd += ['-i', vid_path, '-start_number', '0', outformat]

    try:
        failed = subprocess.call(cmd) != 0
    except OSError:
        # ffmpeg missing or not executable: report it like any other failure.
        failed = True
    if failed:
        print('ffmpeg process video: {} error'.format(vid_name))

    sys.stdout.flush()
    return out_full_path
if __name__ == '__main__': if __name__ == '__main__':
paddle.disable_static() paddle.disable_static()
args = parser.parse_args() args = parser.parse_args()
......
...@@ -41,6 +41,11 @@ dataset: ...@@ -41,6 +41,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: False no_flip: False
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
test: test:
name: SingleDataset name: SingleDataset
dataroot: data/cityscapes/testB dataroot: data/cityscapes/testB
...@@ -55,6 +60,11 @@ dataset: ...@@ -55,6 +60,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: True no_flip: True
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
optimizer: optimizer:
......
...@@ -40,6 +40,11 @@ dataset: ...@@ -40,6 +40,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: False no_flip: False
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
test: test:
name: SingleDataset name: SingleDataset
dataroot: data/horse2zebra/testA dataroot: data/horse2zebra/testA
...@@ -54,7 +59,11 @@ dataset: ...@@ -54,7 +59,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: True no_flip: True
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -38,6 +38,11 @@ dataset: ...@@ -38,6 +38,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: False no_flip: False
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes/ dataroot: data/cityscapes/
...@@ -53,6 +58,11 @@ dataset: ...@@ -53,6 +58,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: True no_flip: True
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -37,6 +37,11 @@ dataset: ...@@ -37,6 +37,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: False no_flip: False
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes/ dataroot: data/cityscapes/
...@@ -52,6 +57,11 @@ dataset: ...@@ -52,6 +57,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: True no_flip: True
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -37,6 +37,11 @@ dataset: ...@@ -37,6 +37,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: False no_flip: False
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/facades/ dataroot: data/facades/
...@@ -52,6 +57,11 @@ dataset: ...@@ -52,6 +57,11 @@ dataset:
crop_size: 256 crop_size: 256
preprocess: resize_and_crop preprocess: resize_and_crop
no_flip: True no_flip: True
normalize:
mean:
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
optimizer: optimizer:
name: Adam name: Adam
......
from .unpaired_dataset import UnpairedDataset from .unpaired_dataset import UnpairedDataset
from .single_dataset import SingleDataset from .single_dataset import SingleDataset
from .paired_dataset import PairedDataset from .paired_dataset import PairedDataset
from .sr_image_dataset import SRImageDataset
\ No newline at end of file
...@@ -94,7 +94,9 @@ def get_transform(cfg, ...@@ -94,7 +94,9 @@ def get_transform(cfg,
if convert: if convert:
transform_list += [transforms.Permute(to_rgb=True)] transform_list += [transforms.Permute(to_rgb=True)]
transform_list += [ if cfg.get('normalize', None):
transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)) transform_list += [
] transforms.Normalize(cfg.normalize.mean, cfg.normalize.std)
]
return transforms.Compose(transform_list) return transforms.Compose(transform_list)
# import mmcv
import os
import cv2
import random
import numpy as np
import paddle.vision.transforms as transform
from pathlib import Path
from paddle.io import Dataset
from .builder import DATASETS
def scandir(dir_path, suffix=None, recursive=False):
    """Lazily yield the paths of the non-hidden files found in a directory.

    Args:
        dir_path (str | Path): Directory to scan.
        suffix (str | tuple[str] | None): When given, only files whose path
            ends with this suffix (or one of these suffixes) are yielded.
        recursive (bool): Whether to descend into sub-directories.

    Returns:
        generator: File paths relative to ``dir_path``. Files whose name
        starts with ``.`` are never yielded.

    Raises:
        TypeError: If ``dir_path`` or ``suffix`` has an unsupported type.
    """
    if not isinstance(dir_path, (str, Path)):
        raise TypeError('"dir_path" must be a string or Path object')
    if suffix is not None and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = str(dir_path)

    def _scandir(current_dir):
        for entry in os.scandir(current_dir):
            if not entry.name.startswith('.') and entry.is_file():
                rel_path = os.path.relpath(entry.path, root)
                if suffix is None or rel_path.endswith(suffix):
                    yield rel_path
            elif recursive and entry.is_dir():
                # Only directories are descended into. Hidden files and
                # other non-directory entries also reach this branch and
                # must be skipped, otherwise os.scandir raises on them.
                yield from _scandir(entry.path)

    return _scandir(root)
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Pair every ground-truth file with its input file of the matching name.

    Args:
        folders (list[str]): ``[input_folder, gt_folder]``.
        keys (list[str]): ``[input_key, gt_key]``; used to name the entries
            of each returned dict and in error messages.
        filename_tmpl (str): Format template applied to the ground-truth
            base name (without extension) to obtain the input base name.

    Returns:
        list[dict]: One dict per ground-truth file with the entries
        ``'<input_key>_path'`` and ``'<gt_key>_path'``.

    Raises:
        AssertionError: If ``folders`` or ``keys`` does not have two items,
            the folders hold a different number of files, or an expected
            input file is missing.
    """
    assert len(folders) == 2, (
        'The len of folders should be 2 with [input_folder, gt_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 2, (
        'The len of keys should be 2 with [input_key, gt_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (
        f'{input_key} and {gt_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(gt_paths)}.')

    # A set makes each membership test O(1) instead of scanning the list
    # once per ground-truth file.
    available_inputs = set(input_paths)

    paths = []
    for gt_path in gt_paths:
        basename, ext = os.path.splitext(os.path.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        assert input_name in available_inputs, (f'{input_name} is not in '
                                                f'{input_key}_paths.')
        paths.append({
            f'{input_key}_path': os.path.join(input_folder, input_name),
            f'{gt_key}_path': os.path.join(gt_folder, gt_path),
        })
    return paths
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth; only used in error messages.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If the GT size is not ``scale`` times the LQ size, or the
            LQ image is smaller than the LQ patch.
    """
    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # The two fragments are concatenated into ONE message; a comma
        # between them would make the exception carry a tuple instead.
        raise ValueError(
            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
            f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [
        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        for v in img_lqs
    ]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [
        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
        for v in img_gts
    ]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None):
    """Randomly flip and transpose images (and optional flow maps) together.

    One random decision is drawn per transform and shared by every image
    and flow: a horizontal flip (when ``hflip``), a vertical flip and an
    H/W transpose (both when ``rotation``), each with probability 0.5.
    Flips are written back into the input arrays.

    Args:
        imgs (list[ndarray] | ndarray): Image(s) to augment.
        hflip (bool): Allow the horizontal flip.
        rotation (bool): Allow the vertical flip and the transpose.
        flows (list[ndarray] | ndarray | None): Optional flow map(s); the
            matching flow channel is negated on a flip and the two flow
            channels are swapped on a transpose.

    Returns:
        The augmented image(s), or ``(images, flows)`` when ``flows`` is
        given. A single-element list is unwrapped to its element.
    """
    # Draw order (hflip, vflip, transpose) and short-circuiting are kept so
    # the random stream is consumed exactly as before.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rotation and random.random() < 0.5
    do_transpose = rotation and random.random() < 0.5

    def _transform_image(image):
        if do_hflip:
            cv2.flip(image, 1, image)
        if do_vflip:
            cv2.flip(image, 0, image)
        return image.transpose(1, 0, 2) if do_transpose else image

    def _transform_flow(field):
        if do_hflip:
            cv2.flip(field, 1, field)
            field[:, :, 0] *= -1
        if do_vflip:
            cv2.flip(field, 0, field)
            field[:, :, 1] *= -1
        if do_transpose:
            field = field.transpose(1, 0, 2)
            field = field[:, :, [1, 0]]
        return field

    def _unwrap(items):
        return items[0] if len(items) == 1 else items

    image_list = imgs if isinstance(imgs, list) else [imgs]
    out_imgs = _unwrap([_transform_image(image) for image in image_list])

    if flows is None:
        return out_imgs

    flow_list = flows if isinstance(flows, list) else [flows]
    out_flows = _unwrap([_transform_flow(field) for field in flow_list])
    return out_imgs, out_flows
@DATASETS.register()
class SRImageDataset(Dataset):
    """Paired image dataset for image restoration.

    Reads low-quality (lq) / ground-truth (gt) image pairs from two folders.
    Keys read from ``cfg``: ``io_backend`` (its ``type`` entry),
    ``dataroot_gt``, ``dataroot_lq``, ``scale``, ``phase``, and optionally
    ``filename_tmpl`` and ``meta_info_file``. In the ``train`` phase
    ``gt_size``, ``use_flip`` and ``use_rot`` are read as well.
    """

    def __init__(self, cfg):
        super(SRImageDataset, self).__init__()
        self.cfg = cfg

        self.file_client = None
        self.io_backend_opt = cfg['io_backend']

        self.gt_folder, self.lq_folder = cfg['dataroot_gt'], cfg['dataroot_lq']
        if 'filename_tmpl' in cfg:
            self.filename_tmpl = cfg['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        # The two unsupported modes used to fall through silently and leave
        # ``self.paths`` undefined, which only surfaced later as an
        # AttributeError in __len__/__getitem__. Fail early and clearly.
        if self.io_backend_opt['type'] == 'lmdb':
            # TODO(LielinJiang): support lmdb to accelerate io
            raise NotImplementedError(
                'SRImageDataset does not support the lmdb io backend yet.')
        elif 'meta_info_file' in self.cfg and self.cfg[
                'meta_info_file'] is not None:
            # TODO(LielinJiang): support building the pair list from a
            # meta info file
            raise NotImplementedError(
                'SRImageDataset does not support meta_info_file yet.')
        else:
            self.paths = paired_paths_from_folder(
                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
                self.filename_tmpl)

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV and scale it to float32 in [0, 1]."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals a missing/undecodable file by returning
            # None instead of raising.
            raise IOError(f'Failed to read image: {path}')
        return img.astype(np.float32) / 255.

    def __getitem__(self, index):
        """Return the ``index``-th pair as a dict with lq/gt images and paths."""
        scale = self.cfg['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        lq_path = self.paths[index]['lq_path']
        img_gt = self._read_image(gt_path)
        img_lq = self._read_image(lq_path)

        # augmentation for training
        if self.cfg['phase'] == 'train':
            gt_size = self.cfg['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
                                                gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.cfg['use_flip'],
                                     self.cfg['use_rot'])

        # TODO: color space transform
        # Channel/layout conversion is delegated to paddle's Permute
        # transform with its default arguments.
        permute = transform.Permute()
        img_gt = permute(img_gt)
        img_lq = permute(img_lq)
        return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path
        }

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.paths)
import os import os
import time import time
import copy
import logging import logging
import paddle import paddle
...@@ -10,7 +11,7 @@ from ..datasets.builder import build_dataloader ...@@ -10,7 +11,7 @@ from ..datasets.builder import build_dataloader
from ..models.builder import build_model from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs from ..utils.filesystem import save, load, makedirs
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
class Trainer: class Trainer:
def __init__(self, cfg): def __init__(self, cfg):
...@@ -39,12 +40,17 @@ class Trainer: ...@@ -39,12 +40,17 @@ class Trainer:
self.weight_interval = cfg.snapshot_config.interval self.weight_interval = cfg.snapshot_config.interval
self.log_interval = cfg.log_config.interval self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval self.visual_interval = cfg.log_config.visiual_interval
self.validate_interval = -1
if cfg.get('validate', None) is not None:
self.validate_interval = cfg.validate.get('interval', -1)
self.cfg = cfg self.cfg = cfg
self.local_rank = ParallelEnv().local_rank self.local_rank = ParallelEnv().local_rank
# time count # time count
self.time_count = {} self.time_count = {}
self.best_metric = {}
def distributed_data_parallel(self): def distributed_data_parallel(self):
strategy = paddle.distributed.prepare_context() strategy = paddle.distributed.prepare_context()
...@@ -78,11 +84,58 @@ class Trainer: ...@@ -78,11 +84,58 @@ class Trainer:
step_start_time = time.time() step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() - self.logger.info('train one epoch time: {}'.format(time.time() -
start_time)) start_time))
if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate()
self.model.lr_scheduler.step() self.model.lr_scheduler.step()
if epoch % self.weight_interval == 0: if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1) self.save(epoch, 'weight', keep=-1)
self.save(epoch) self.save(epoch)
def validate(self):
    """Run the model over the validation set, save visuals and log metrics.

    The validation dataloader is built from ``cfg.dataset.val`` on first
    use. For every sample the visuals are saved under ``visual_val``; the
    metrics listed in ``cfg.validate.metrics`` (``psnr`` and/or ``ssim``)
    are summed over all samples, divided by the dataset length and logged.
    """
    if not hasattr(self, 'val_dataloader'):
        self.val_dataloader = build_dataloader(self.cfg.dataset.val,
                                               is_train=False)

    # Evaluated in this fixed order for every sample.
    metric_fns = (('psnr', calculate_psnr), ('ssim', calculate_ssim))
    metric_result = {}

    for i, data in enumerate(self.val_dataloader):
        self.batch_id = i

        self.model.set_input(data)
        self.model.test()

        visual_results = {}
        current_paths = self.model.get_image_paths()
        current_visuals = self.model.get_current_visuals()

        for j, image_path in enumerate(current_paths):
            basename = os.path.splitext(os.path.basename(image_path))[0]
            for k, img_tensor in current_visuals.items():
                visual_results['%s_%s' % (basename, k)] = img_tensor[j]

            for metric_name, metric_fn in metric_fns:
                if metric_name not in self.cfg.validate.metrics:
                    continue
                # Images are converted per metric, so the 'output' / 'gt'
                # visuals are only required when a metric is configured.
                score = metric_fn(
                    tensor2img(current_visuals['output'][j], (0., 1.)),
                    tensor2img(current_visuals['gt'][j], (0., 1.)),
                    **getattr(self.cfg.validate.metrics, metric_name))
                if metric_name in metric_result:
                    metric_result[metric_name] += score
                else:
                    metric_result[metric_name] = score

        self.visual('visual_val', visual_results=visual_results)

        if i % self.log_interval == 0:
            self.logger.info('val iter: [%d/%d]' %
                             (i, len(self.val_dataloader)))

    num_samples = len(self.val_dataloader.dataset)
    for metric_name in metric_result.keys():
        metric_result[metric_name] /= num_samples

    self.logger.info('Epoch {} validate end: {}'.format(
        self.current_epoch, metric_result))
def test(self): def test(self):
if not hasattr(self, 'test_dataloader'): if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test, self.test_dataloader = build_dataloader(self.cfg.dataset.test,
...@@ -147,8 +200,11 @@ class Trainer: ...@@ -147,8 +200,11 @@ class Trainer:
msg = '' msg = ''
makedirs(os.path.join(self.output_dir, results_dir)) makedirs(os.path.join(self.output_dir, results_dir))
min_max = self.cfg.get('min_max', None)
if min_max is None:
min_max = (-1., 1.)
for label, image in visual_results.items(): for label, image in visual_results.items():
image_numpy = tensor2img(image) image_numpy = tensor2img(image, min_max)
img_path = os.path.join(self.output_dir, results_dir, img_path = os.path.join(self.output_dir, results_dir,
msg + '%s.png' % (label)) msg + '%s.png' % (label))
save_image(image_numpy, img_path) save_image(image_numpy, img_path)
...@@ -210,5 +266,6 @@ class Trainer: ...@@ -210,5 +266,6 @@ class Trainer:
for name in self.model.model_names: for name in self.model.model_names:
if isinstance(name, str): if isinstance(name, str):
self.logger.info('laod model {} {} params!'.format(self.cfg.model.name, 'net' + name))
net = getattr(self.model, 'net' + name) net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name]) net.set_dict(state_dicts['net' + name])
import numpy as np
def reorder_image(img, input_order='HWC'):
    """Bring an image into channel-last ('HWC') layout.

    A 2-D ``(h, w)`` image gets a trailing singleton channel axis and
    ``input_order`` is then irrelevant (it is still validated).  A ``(c, h, w)``
    image is transposed to ``(h, w, c)``; an ``(h, w, c)`` image is returned
    unchanged.

    Args:
        img (ndarray): Input image.
        input_order (str): Layout of ``img``, either 'HWC' or 'CHW'.
            Default: 'HWC'.

    Returns:
        ndarray: The image in 'HWC' layout.

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            "'HWC' and 'CHW'")

    # Grayscale input: only a channel axis has to be appended.
    if len(img.shape) == 2:
        return img[..., None]

    if input_order == 'CHW':
        return img.transpose(1, 2, 0)
    return img
def _convert_input_type_range(img):
    """Return ``img`` as np.float32 with values in [0, 1].

    Args:
        img (ndarray): np.uint8 image in [0, 255] or np.float32 image in
            [0, 1].

    Returns:
        ndarray: np.float32 image in [0, 1].

    Raises:
        TypeError: If the dtype of ``img`` is neither np.uint8 nor np.float32.
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        return img
    if img_type == np.uint8:
        return img / 255.
    raise TypeError('The img type should be np.float32 or np.uint8, '
                    f'but got {img_type}')


def _convert_output_type_range(img, dst_type):
    """Convert an image in [0, 255] back to ``dst_type`` and its range.

    np.uint8 output is rounded and kept in [0, 255]; np.float32 output is
    rescaled to [0, 1].

    Args:
        img (ndarray): Image with values in [0, 255].
        dst_type (np.uint8 | np.float32): Requested output dtype.

    Returns:
        ndarray: The image with the requested dtype and matching range.

    Raises:
        TypeError: If ``dst_type`` is neither np.uint8 nor np.float32.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, '
                        f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img = img / 255.
    return img.astype(dst_type)


def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.

    Raises:
        TypeError: If ``img`` is neither np.uint8 nor np.float32.
    """
    img_type = img.dtype
    # Work on float32 in [0, 1]; the coefficients below then yield values in
    # the [0, 255] scale.
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                  [65.481, -37.797, 112.0]]) + [16, 128, 128]
    # Restore the caller's dtype and value range.
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
def to_y_channel(img):
    """Reduce an image to the Y (luma) channel of YCbCr.

    Three-channel images are converted with ``bgr2ycbcr(..., y_only=True)``
    and get a trailing singleton channel axis; any other input is only
    rescaled to float32 and back.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    scaled = img.astype(np.float32) / 255.
    has_three_channels = scaled.ndim == 3 and scaled.shape[2] == 3
    if has_three_channels:
        scaled = bgr2ycbcr(scaled, y_only=True)[..., None]
    return scaled * 255.
import cv2
import numpy as np
from .metric_util import reorder_image, to_y_channel
def calculate_psnr(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result (``inf`` when both images are identical).

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """
    assert img1.shape == img2.shape, (
        f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')
    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    # Differencing must happen in floating point: with uint8 inputs both the
    # subtraction and the square would silently wrap around modulo 256.
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff**2)
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))
def _ssim(img1, img2):
    """Compute SSIM (structural similarity) for a pair of single-channel images.

    Helper of :func:`calculate_ssim`.  Local statistics are taken with an
    11x11 Gaussian window (sigma 1.5) and the 5-pixel border of every filtered
    map is discarded before averaging.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    first = img1.astype(np.float64)
    second = img2.astype(np.float64)

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def _local_mean(values):
        # Gaussian-weighted local mean, cropped to the valid interior.
        return cv2.filter2D(values, -1, window)[5:-5, 5:-5]

    mu1 = _local_mean(first)
    mu2 = _local_mean(second)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = _local_mean(first**2) - mu1_sq
    sigma2_sq = _local_mean(second**2) - mu2_sq
    sigma12 = _local_mean(first * second) - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    ssim_map = numerator / denominator
    return ssim_map.mean()
def calculate_ssim(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For multi-channel images, SSIM is computed per channel with ``_ssim`` and
    the per-channel values are averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """
    assert img1.shape == img2.shape, (
        f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    first = reorder_image(img1, input_order=input_order)
    second = reorder_image(img2, input_order=input_order)

    if crop_border != 0:
        # Same margin is removed from both spatial axes.
        inner = slice(crop_border, -crop_border)
        first = first[inner, inner, ...]
        second = second[inner, inner, ...]

    if test_y_channel:
        first = to_y_channel(first)
        second = to_y_channel(second)

    per_channel = [
        _ssim(first[..., channel], second[..., channel])
        for channel in range(first.shape[2])
    ]
    return np.array(per_channel).mean()
from .base_model import BaseModel from .base_model import BaseModel
from .cycle_gan_model import CycleGANModel from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
from .sr_model import SRModel
from .resnet_backbone import resnet18, resnet34, resnet50, resnet101, resnet152
\ No newline at end of file
from .resnet import ResnetGenerator from .resnet import ResnetGenerator
from .unet import UnetGenerator from .unet import UnetGenerator
\ No newline at end of file from .rrdb_net import RRDBNet
\ No newline at end of file
...@@ -3,10 +3,9 @@ import paddle ...@@ -3,10 +3,9 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from resnet_backbone import resnet34, resnet101 from .hook import hook_outputs, model_sizes, dummy_eval
from hook import hook_outputs, model_sizes, dummy_eval from ..backbones import resnet34, resnet101
from spectral_norm import Spectralnorm from ...modules.nn import Spectralnorm
from paddle import fluid
class SequentialEx(nn.Layer): class SequentialEx(nn.Layer):
...@@ -206,7 +205,7 @@ class UnetBlockWide(nn.Layer): ...@@ -206,7 +205,7 @@ class UnetBlockWide(nn.Layer):
return self.conv(cat_x) return self.conv(cat_x)
class UnetBlockDeep(paddle.fluid.Layer): class UnetBlockDeep(nn.Layer):
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`." "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
def __init__( def __init__(
...@@ -319,7 +318,7 @@ def conv_layer(ni: int, ...@@ -319,7 +318,7 @@ def conv_layer(ni: int,
return nn.Sequential(*layers) return nn.Sequential(*layers)
class CustomPixelShuffle_ICNR(paddle.fluid.Layer): class CustomPixelShuffle_ICNR(nn.Layer):
"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`." "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
def __init__(self, def __init__(self,
...@@ -349,7 +348,7 @@ class CustomPixelShuffle_ICNR(paddle.fluid.Layer): ...@@ -349,7 +348,7 @@ class CustomPixelShuffle_ICNR(paddle.fluid.Layer):
return self.blur(self.pad(x)) if self.blur else x return self.blur(self.pad(x)) if self.blur else x
class MergeLayer(paddle.fluid.Layer): class MergeLayer(nn.Layer):
"Merge a shortcut with the result of the module by adding them or concatenating thme if `dense=True`." "Merge a shortcut with the result of the module by adding them or concatenating thme if `dense=True`."
def __init__(self, dense: bool = False): def __init__(self, dense: bool = False):
...@@ -379,7 +378,7 @@ def res_block(nf, ...@@ -379,7 +378,7 @@ def res_block(nf,
MergeLayer(dense)) MergeLayer(dense))
class SigmoidRange(paddle.fluid.Layer): class SigmoidRange(nn.Layer):
"Sigmoid module with range `(low,x_max)`" "Sigmoid module with range `(low,x_max)`"
def __init__(self, low, high): def __init__(self, low, high):
...@@ -395,13 +394,13 @@ def sigmoid_range(x, low, high): ...@@ -395,13 +394,13 @@ def sigmoid_range(x, low, high):
return F.sigmoid(x) * (high - low) + low return F.sigmoid(x) * (high - low) + low
class PixelShuffle(paddle.fluid.Layer): class PixelShuffle(nn.Layer):
def __init__(self, upscale_factor): def __init__(self, upscale_factor):
super(PixelShuffle, self).__init__() super(PixelShuffle, self).__init__()
self.upscale_factor = upscale_factor self.upscale_factor = upscale_factor
def forward(self, x): def forward(self, x):
return paddle.fluid.layers.pixel_shuffle(x, self.upscale_factor) return F.pixel_shuffle(x, self.upscale_factor)
class ReplicationPad2d(nn.Layer): class ReplicationPad2d(nn.Layer):
...@@ -410,7 +409,7 @@ class ReplicationPad2d(nn.Layer): ...@@ -410,7 +409,7 @@ class ReplicationPad2d(nn.Layer):
self.size = size self.size = size
def forward(self, x): def forward(self, x):
return paddle.fluid.layers.pad2d(x, self.size, mode="edge") return F.pad2d(x, self.size, mode="edge")
def conv1d(ni: int, def conv1d(ni: int,
......
...@@ -3,6 +3,8 @@ import paddle ...@@ -3,6 +3,8 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from .builder import GENERATORS
class ResidualDenseBlock_5C(nn.Layer): class ResidualDenseBlock_5C(nn.Layer):
def __init__(self, nf=64, gc=32, bias=True): def __init__(self, nf=64, gc=32, bias=True):
...@@ -15,6 +17,7 @@ class ResidualDenseBlock_5C(nn.Layer): ...@@ -15,6 +17,7 @@ class ResidualDenseBlock_5C(nn.Layer):
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias) self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2) self.lrelu = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x): def forward(self, x):
x1 = self.lrelu(self.conv1(x)) x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1))) x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1)))
...@@ -26,6 +29,7 @@ class ResidualDenseBlock_5C(nn.Layer): ...@@ -26,6 +29,7 @@ class ResidualDenseBlock_5C(nn.Layer):
class RRDB(nn.Layer): class RRDB(nn.Layer):
'''Residual in Residual Dense Block''' '''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32): def __init__(self, nf, gc=32):
super(RRDB, self).__init__() super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.RDB1 = ResidualDenseBlock_5C(nf, gc)
...@@ -38,7 +42,6 @@ class RRDB(nn.Layer): ...@@ -38,7 +42,6 @@ class RRDB(nn.Layer):
out = self.RDB3(out) out = self.RDB3(out)
return out * 0.2 + x return out * 0.2 + x
def make_layer(block, n_layers): def make_layer(block, n_layers):
layers = [] layers = []
for _ in range(n_layers): for _ in range(n_layers):
...@@ -46,6 +49,7 @@ def make_layer(block, n_layers): ...@@ -46,6 +49,7 @@ def make_layer(block, n_layers):
return nn.Sequential(*layers) return nn.Sequential(*layers)
@GENERATORS.register()
class RRDBNet(nn.Layer): class RRDBNet(nn.Layer):
def __init__(self, in_nc, out_nc, nf, nb, gc=32): def __init__(self, in_nc, out_nc, nf, nb, gc=32):
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
...@@ -67,10 +71,8 @@ class RRDBNet(nn.Layer): ...@@ -67,10 +71,8 @@ class RRDBNet(nn.Layer):
trunk = self.trunk_conv(self.RRDB_trunk(fea)) trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk fea = fea + trunk
fea = self.lrelu( fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea))) out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out return out
from collections import OrderedDict
import paddle
import paddle.nn as nn
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_optimizer
from .base_model import BaseModel
from .losses import GANLoss
from .builder import MODELS
import importlib
from collections import OrderedDict
from copy import deepcopy
from os import path as osp
from .builder import MODELS
@MODELS.register()
class SRModel(BaseModel):
    """Base SR model for single image super-resolution.

    Wraps a single generator ``netG`` built from ``cfg.model.generator``.
    In training mode an L1 pixel loss and one optimizer for the generator
    are created as well.
    """
    def __init__(self, cfg):
        """Build the generator and, when training, the loss and optimizer.

        Args:
            cfg: Configuration object; ``cfg.model.generator`` and
                ``cfg.optimizer`` are read here.
        """
        super(SRModel, self).__init__(cfg)

        self.model_names = ['G']

        self.netG = build_generator(cfg.model.generator)
        self.visual_names = ['lq', 'output', 'gt']

        self.loss_names = ['l_total']

        self.optimizers = []
        if self.isTrain:
            self.criterionL1 = paddle.nn.L1Loss()

            self.build_lr_scheduler()
            self.optimizer_G = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.netG.parameters())
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Store one batch on the model.

        Args:
            input (dict): Must provide 'lq' and 'lq_path'; 'gt' is optional.
        """
        self.lq = paddle.to_tensor(input['lq'])
        if 'gt' in input:
            self.gt = paddle.to_tensor(input['gt'])
        self.image_paths = input['lq_path']

    def forward(self):
        """No-op; ``test`` and ``optimize_parameters`` call ``netG`` directly."""
        pass

    def test(self):
        """Forward function used in test time.
        """
        with paddle.no_grad():
            self.output = self.netG(self.lq)

    def optimize_parameters(self):
        """Run one training step: forward, L1 pixel loss, backward, update."""
        self.optimizer_G.clear_grad()
        self.output = self.netG(self.lq)

        l_total = 0
        # pixel loss
        if self.criterionL1:
            l_pix = self.criterionL1(self.output, self.gt)
            l_total += l_pix

        l_total.backward()
        # Stored under the name announced in ``self.loss_names``.
        self.loss_l_total = l_total
        self.optimizer_G.step()
from collections import OrderedDict
import paddle
import paddle.nn as nn
from .generators.builder import build_generator
from .base_model import BaseModel
from .losses import GANLoss
from .builder import MODELS
@MODELS.register()
class SRGANModel(BaseModel):
    """SRGAN model wrapper; only the generator is built and no training step
    is implemented yet.
    """
    def __init__(self, cfg):
        """Build the generator from ``cfg.model.generator``."""
        super(SRGANModel, self).__init__(cfg)

        # define networks
        self.model_names = ['G']

        self.netG = build_generator(cfg.model.generator)
        self.visual_names = ['LQ', 'GT', 'fake_H']

        # TODO: support srgan train.
        # Training is not wired up yet; once it is, a discriminator
        # (build_discriminator(cfg.model.discriminator)) has to be built here
        # and both networks switched to train mode.

    def set_input(self, input):
        """Unpack input data from the dataloader.

        Parameters:
            input (dict): may hold 'A' (stored as ``self.LQ``), 'B' (stored as
                ``self.GT``) and 'A_paths' (stored as ``self.image_paths``);
                keys that are absent leave the attribute untouched.
        """
        for key, attr in (('A', 'LQ'), ('B', 'GT')):
            if key in input:
                setattr(self, attr, paddle.to_tensor(input[key]))
        if 'A_paths' in input:
            self.image_paths = input['A_paths']

    def forward(self):
        """Run the generator on the low-quality input."""
        self.fake_H = self.netG(self.LQ)

    def optimize_parameters(self, step):
        """Placeholder: SRGAN training is not supported yet."""
...@@ -69,21 +69,59 @@ class BCEWithLogitsLoss(): ...@@ -69,21 +69,59 @@ class BCEWithLogitsLoss():
return out return out
# class BCEWithLogitsLoss(fluid.dygraph.Layer): class _SpectralNorm(paddle.nn.SpectralNorm):
# def __init__(self, weight=None, reduction='mean'): def __init__(self,
# if reduction not in ['sum', 'mean', 'none']: weight_shape,
# raise ValueError( dim=0,
# "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but " power_iters=1,
# "received %s, which is not allowed." % reduction) eps=1e-12,
dtype='float32'):
# super(BCEWithLogitsLoss, self).__init__() super(_SpectralNorm, self).__init__(weight_shape, dim, power_iters, eps, dtype)
# # self.weight = weight
# # self.reduction = reduction def forward(self, weight):
# self.bce_loss = paddle.nn.BCELoss(weight, reduction) paddle.fluid.data_feeder.check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
'SpectralNorm')
# def forward(self, input, label): inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
# input = paddle.nn.functional.sigmoid(input, True) out = self._helper.create_variable_for_type_inference(self._dtype)
# return self.bce_loss(input, label) _power_iters = self._power_iters if self.training else 0
self._helper.append_op(
type="spectral_norm",
inputs=inputs,
outputs={"Out": out, },
attrs={
"dim": self._dim,
"power_iters": _power_iters,
"eps": self._eps,
})
return out
class Spectralnorm(paddle.nn.Layer):
    """Wrap a layer so its weight goes through ``_SpectralNorm`` on every call.

    The wrapped layer's 'weight' parameter is removed from the layer and a
    copy is kept on this wrapper as ``weight_orig``; ``forward`` assigns the
    output of ``_SpectralNorm`` back onto the layer before calling it.
    """
    def __init__(self,
                 layer,
                 dim=0,
                 power_iters=1,
                 eps=1e-12,
                 dtype='float32'):
        """
        Args:
            layer: Layer whose ``weight`` is passed through ``_SpectralNorm``.
            dim, power_iters, eps, dtype: Forwarded to ``_SpectralNorm``.
        """
        super(Spectralnorm, self).__init__()
        # Sub-layers are assigned in the same order as before so the
        # registration order stays unchanged.
        self.spectral_norm = _SpectralNorm(layer.weight.shape, dim,
                                           power_iters, eps, dtype)
        self.dim = dim
        self.power_iters = power_iters
        self.eps = eps
        self.layer = layer

        # Move the raw weight from the wrapped layer onto this wrapper.
        raw_weight = layer._parameters.pop('weight')
        self.weight_orig = self.create_parameter(raw_weight.shape,
                                                 dtype=raw_weight.dtype)
        self.weight_orig.set_value(raw_weight)

    def forward(self, x):
        """Apply the wrapped layer using the spectrally normalized weight."""
        self.layer.weight = self.spectral_norm(self.weight_orig)
        return self.layer(x)
def initial_type( def initial_type(
......
import os
import sys
def video2frames(video_path, outpath, **kargs):
    """Dump the frames of a video to numbered PNG files with ffmpeg.

    Frames are written to ``<outpath>/<video name>/%08d.png`` starting at
    number 0; the directory is created when missing.

    Args:
        video_path (str): Path of the input video.
        outpath (str): Directory under which the frame folder is created.
        **kargs: Extra ffmpeg arguments.  Each key and value is converted
            with ``str`` and appended, in order, after the output pattern
            (same position as before); both are word-split shell-style.

    Returns:
        str: The directory holding the extracted frames.

    Raises:
        RuntimeError: If ffmpeg cannot be started or exits with a non-zero
            status.
    """
    # Local imports keep this block self-contained.
    import shlex
    import subprocess

    # Kept as-is so the folder name matches what callers derive themselves.
    vid_name = video_path.split('/')[-1].split('.')[0]
    out_full_path = os.path.join(outpath, vid_name)
    os.makedirs(out_full_path, exist_ok=True)

    # video file name
    outformat = out_full_path + '/%08d.png'

    # An argv list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and rules out command injection.
    cmd = [
        'ffmpeg', '-y', '-loglevel', 'error', '-i', video_path,
        '-start_number', '0', outformat
    ]
    for k, v in kargs.items():
        cmd.extend(shlex.split(str(k)))
        cmd.extend(shlex.split(str(v)))

    try:
        returncode = subprocess.run(cmd).returncode
    except OSError as err:
        # e.g. ffmpeg is not installed; report it like any other failure.
        raise RuntimeError(
            'ffmpeg process video: {} error'.format(vid_name)) from err
    if returncode != 0:
        raise RuntimeError('ffmpeg process video: {} error'.format(vid_name))

    sys.stdout.flush()
    return out_full_path
def frames2video(frame_path, video_path, r):
    """Encode an image sequence into an H.264 video with ffmpeg.

    Args:
        frame_path (str): Input frame pattern passed to ffmpeg ``-i``
            (image2 demuxer), e.g. ``frames/%08d.png``.
        video_path (str): Output video file; overwritten if it exists.
        r (str | int): Frame rate passed to ffmpeg ``-r``.

    Raises:
        RuntimeError: If ffmpeg cannot be started or exits with a non-zero
            status.
    """
    # Local import keeps this block self-contained.
    import subprocess

    # An argv list (no shell) keeps paths containing spaces or shell
    # metacharacters intact; ``str(r)`` also accepts a numeric frame rate.
    cmd = [
        'ffmpeg', '-y', '-loglevel', 'error', '-r', str(r), '-f', 'image2',
        '-i', frame_path, '-vcodec', 'libx264', '-pix_fmt', 'yuv420p',
        '-crf', '16', video_path
    ]

    try:
        returncode = subprocess.run(cmd).returncode
    except OSError as err:
        # e.g. ffmpeg is not installed; report it like any other failure.
        raise RuntimeError(
            'ffmpeg process video: {} error'.format(video_path)) from err
    if returncode != 0:
        raise RuntimeError('ffmpeg process video: {} error'.format(video_path))

    sys.stdout.flush()
\ No newline at end of file
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
from PIL import Image from PIL import Image
def tensor2img(input_image, imtype=np.uint8): def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8):
""""Converts a Tensor array into a numpy image array. """"Converts a Tensor array into a numpy image array.
Parameters: Parameters:
...@@ -15,7 +15,9 @@ def tensor2img(input_image, imtype=np.uint8): ...@@ -15,7 +15,9 @@ def tensor2img(input_image, imtype=np.uint8):
image_numpy = image_numpy[0] image_numpy = image_numpy[0]
if image_numpy.shape[0] == 1: # grayscale to RGB if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1)) image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling image_numpy = image_numpy.clip(min_max[0], min_max[1])
image_numpy = (image_numpy - min_max[0]) / (min_max[1] - min_max[0])
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling
else: # if it is a numpy array, do nothing else: # if it is a numpy array, do nothing
image_numpy = input_image image_numpy = input_image
return image_numpy.astype(imtype) return image_numpy.astype(imtype)
......
tqdm
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup
from io import open
# NOTE(review): the file name is spelled 'requirments.txt'; presumably this
# matches the file on disk - confirm before renaming either side.
with open('requirments.txt', encoding="utf-8-sig") as f:
    # Keep only real requirement specifiers: strip the trailing newlines that
    # readlines() would leave in place and drop blank and comment lines.
    requirements = [
        line.strip() for line in f
        if line.strip() and not line.strip().startswith('#')
    ]
def readme():
    """Return the text of 'doc/doc_en/whl_en.md' (a leading BOM is dropped)."""
    with open('doc/doc_en/whl_en.md', encoding="utf-8-sig") as doc_file:
        return doc_file.read()
# Imported here so this block is self-contained.
from setuptools import find_packages

setup(
    name='ppgan',
    # A bare ['ppgan'] would ship only the top-level package; the sub-packages
    # (ppgan.models, ppgan.utils, ...) have to be listed as well.
    packages=find_packages(include=['ppgan', 'ppgan.*']),
    include_package_data=True,
    # NOTE(review): the entry point targets a 'paddlegan' package while the
    # distributed package is 'ppgan' - confirm that module exists.
    entry_points={"console_scripts": ["paddlegan= paddlegan.paddlegan:main"]},
    version='0.1.0',
    install_requires=requirements,
    license='Apache License 2.0',
    description='Awesome GAN toolkits based on PaddlePaddle',
    url='https://github.com/PaddlePaddle/PaddleGAN',
    download_url='https://github.com/PaddlePaddle/PaddleGAN.git',
    keywords=[
        'gan paddlegan'
    ],
    classifiers=[
        'Intended Audience :: Developers', 'Operating System :: OS Independent',
        'Natural Language :: Chinese (Simplified)',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
    ], )
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册