diff --git a/applications/DAIN/predict.py b/applications/DAIN/predict.py
index 38c1d6baa4c833da966baf30b0ef5ee5e7e4f4ef..f43a794150d52e452fdb7d5c3aadcc93cbbda3da 100644
--- a/applications/DAIN/predict.py
+++ b/applications/DAIN/predict.py
@@ -13,8 +13,8 @@ import cv2
 import paddle.fluid as fluid
 from paddle.utils.download import get_path_from_url
+from ppgan.utils.video import video2frames, frames2video
-import networks
 from util import *
 from my_args import parser
@@ -129,7 +129,7 @@ class VideoFrameInterp(object):
         r2 = str(int(fps) * times_interp)
         print("New fps (frame rate): ", r2)
 
-        out_path = dump_frames_ffmpeg(vid, frame_path_input)
+        out_path = video2frames(vid, frame_path_input)
 
         vidname = vid.split('/')[-1].split('.')[0]
@@ -266,7 +266,7 @@ class VideoFrameInterp(object):
                                             vidname + '.mp4')
         if os.path.exists(video_pattern_output):
             os.remove(video_pattern_output)
-        frames_to_video_ffmpeg(frame_pattern_combined, video_pattern_output,
+        frames2video(frame_pattern_combined, video_pattern_output,
                      r2)
 
         return frame_pattern_combined, video_pattern_output
diff --git a/applications/DAIN/util.py b/applications/DAIN/util.py
index 3efbfe0dc7cac0aeed1c624af9c192c381b4fdc5..24ea2741517660581c12d8b174e3e8af03ae9a8e 100644
--- a/applications/DAIN/util.py
+++ b/applications/DAIN/util.py
@@ -21,66 +21,6 @@ class AverageMeter(object):
         self.avg = self.sum / self.count
 
-def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
-    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
-    vid_name = vid_path.split('/')[-1].split('.')[0]
-    out_full_path = os.path.join(outpath, vid_name)
-
-    if not os.path.exists(out_full_path):
-        os.makedirs(out_full_path)
-
-    # video file name
-    outformat = out_full_path + '/%08d.png'
-
-    if ss is not None and t is not None and r is not None:
-        cmd = ffmpeg + [
-            ' -ss ',
-            ss,
-            ' -t ',
-            t,
-            ' -i ',
-            vid_path,
-            ' -r ',
-            r,
-            # ' -f ', ' image2 ',
-            # ' -s ', ' 960*540 ',
-            ' -qscale:v ',
-            ' 0.1 ',
-            ' -start_number ',
-            ' 0 ',
-            # ' -qmax ', ' 1 ',
-            outformat
-        ]
-    else:
-        cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
-
-    cmd = ''.join(cmd)
-
-    if os.system(cmd) == 0:
-        pass
-    else:
-        print('ffmpeg process video: {} error'.format(vid_name))
-
-    sys.stdout.flush()
-    return out_full_path
-
-
-def frames_to_video_ffmpeg(framepath, videopath, r):
-    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
-    cmd = ffmpeg + [
-        ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
-        ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
-    ]
-    cmd = ''.join(cmd)
-
-    if os.system(cmd) == 0:
-        pass
-    else:
-        print('ffmpeg process video: {} error'.format(videopath))
-
-    sys.stdout.flush()
-
-
 def combine_frames(input, interpolated, combined, num_frames):
     frames1 = sorted(glob.glob(os.path.join(input, '*.png')))
     frames2 = sorted(glob.glob(os.path.join(interpolated, '*.png')))
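Both per-application ffmpeg helpers now come from ppgan.utils.video, which is added later in this patch; every predictor follows the same extract/process/reassemble pattern. A minimal round trip under that API (the clip path, output directory, and fps are placeholders):

    from ppgan.utils.video import video2frames, frames2video

    # dumps frames to output/test/%08d.png and returns that directory
    frame_dir = video2frames('test.mp4', 'output')
    # ... process the PNG frames in frame_dir in place ...
    frames2video(frame_dir + '/%08d.png', 'output/test_out.mp4', '25')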
diff --git a/applications/DeOldify/predict.py b/applications/DeOldify/predict.py
index ce637fefb2c49034112a5b59dd657379bd31e8ae..fd94970f38adac3ed494ab0e3208aaa81cd4755f 100644
--- a/applications/DeOldify/predict.py
+++ b/applications/DeOldify/predict.py
@@ -14,8 +14,9 @@ import pickle
 from PIL import Image
 from tqdm import tqdm
 from paddle import fluid
-from model import build_model
 from paddle.utils.download import get_path_from_url
+from ppgan.utils.video import frames2video, video2frames
+from ppgan.models.generators.deoldify import build_model
 
 parser = argparse.ArgumentParser(description='DeOldify')
 parser.add_argument('--input', type=str, default='none',
                     help='Input video')
@@ -29,23 +30,7 @@ parser.add_argument('--weight_path',
                     default=None,
                     help='Path to the reference image directory')
 
-DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
-
-
-def frames_to_video_ffmpeg(framepath, videopath, r):
-    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
-    cmd = ffmpeg + [
-        ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
-        ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
-    ]
-    cmd = ''.join(cmd)
-
-    if os.system(cmd) == 0:
-        pass
-    else:
-        print('ffmpeg process video: {} error'.format(videopath))
-
-    sys.stdout.flush()
+DEOLDIFY_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
 
 
 class DeOldifyPredictor():
@@ -60,7 +45,7 @@ class DeOldifyPredictor():
         self.render_factor = render_factor
         self.model = build_model()
         if weight_path is None:
-            weight_path = get_path_from_url(DeOldify_weight_url, cur_path)
+            weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL, cur_path)
 
         state_dict, _ = paddle.load(weight_path)
         self.model.load_dict(state_dict)
@@ -127,7 +112,7 @@ class DeOldifyPredictor():
         cap = cv2.VideoCapture(vid)
         fps = cap.get(cv2.CAP_PROP_FPS)
 
-        out_path = dump_frames_ffmpeg(vid, output_path)
+        out_path = video2frames(vid, output_path)
 
         frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
@@ -141,42 +126,11 @@ class DeOldifyPredictor():
         vid_out_path = os.path.join(output_path,
                                     '{}_deoldify_out.mp4'.format(base_name))
-        frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
-                               str(int(fps)))
+        frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
 
         return frame_pattern_combined, vid_out_path
 
-
-def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
-    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
-    vid_name = vid_path.split('/')[-1].split('.')[0]
-    out_full_path = os.path.join(outpath, 'frames_input')
-
-    if not os.path.exists(out_full_path):
-        os.makedirs(out_full_path)
-
-    # video file name
-    outformat = out_full_path + '/%08d.png'
-
-    if ss is not None and t is not None and r is not None:
-        cmd = ffmpeg + [
-            ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
-            ' 0.1 ', ' -start_number ', ' 0 ', outformat
-        ]
-    else:
-        cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
-
-    cmd = ''.join(cmd)
-
-    if os.system(cmd) == 0:
-        pass
-    else:
-        print('ffmpeg process video: {} error'.format(vid_name))
-
-    sys.stdout.flush()
-    return out_full_path
-
-
 if __name__ == '__main__':
     paddle.disable_static()
     args = parser.parse_args()
diff --git a/applications/DeOldify/spectral_norm.py b/applications/DeOldify/spectral_norm.py
deleted file mode 100644
index 81500a51d48c46ad8f0628898209fad829f0c67e..0000000000000000000000000000000000000000
--- a/applications/DeOldify/spectral_norm.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import numpy as np
-from paddle import fluid
-from paddle.fluid import dygraph
-from paddle.fluid import layers as F
-from paddle.fluid.layer_helper import LayerHelper
-from paddle.fluid.data_feeder import check_variable_and_dtype
-
-import paddle
-import paddle.nn as nn
-
-class _SpectralNorm(nn.SpectralNorm):
-    def __init__(self,
-                 weight_shape,
-                 dim=0,
-                 power_iters=1,
-                 eps=1e-12,
-                 dtype='float32'):
-        super(_SpectralNorm, self).__init__(weight_shape, dim, power_iters, eps, dtype)
-
-    def forward(self, weight):
-        check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
-                                 'SpectralNorm')
-        inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
-        out = self._helper.create_variable_for_type_inference(self._dtype)
-        _power_iters = self._power_iters if self.training else 0
-        self._helper.append_op(
-            type="spectral_norm",
-            inputs=inputs,
-            outputs={"Out": out, },
-            attrs={
-                "dim": self._dim,
-                "power_iters": _power_iters,  #self._power_iters,
-                "eps": self._eps,
-            })
-
-        return out
-
-
-class Spectralnorm(nn.Layer):
-
-    def __init__(self,
-                 layer,
-                 dim=0,
-                 power_iters=1,
-                 eps=1e-12,
-                 dtype='float32'):
-        super(Spectralnorm, self).__init__()
-        self.spectral_norm = _SpectralNorm(layer.weight.shape, dim, power_iters, eps, dtype)
-        self.dim = dim
-        self.power_iters = power_iters
-        self.eps = eps
-        self.layer = layer
-        weight = layer._parameters['weight']
-        del layer._parameters['weight']
-        self.weight_orig = self.create_parameter(weight.shape, dtype=weight.dtype)
-        self.weight_orig.set_value(weight)
-
-
-    def forward(self, x):
-        weight = self.spectral_norm(self.weight_orig)
-        self.layer.weight = weight
-        out = self.layer(x)
-        return out
diff --git a/applications/DeepRemaster/predict.py b/applications/DeepRemaster/predict.py
index 3ad54b31eff7bb14f630e0c1e66970f96a12e173..baa8b7fd68d4629d57f8d5e31b405fe45b049dd1 100644
--- a/applications/DeepRemaster/predict.py
+++ b/applications/DeepRemaster/predict.py
@@ -14,10 +14,10 @@ from tqdm import tqdm
 import argparse
 import subprocess
 import utils
-from remasternet import NetworkR, NetworkC
+from ppgan.models.generators.remaster import NetworkR, NetworkC
 from paddle.utils.download import get_path_from_url
 
-DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
+DEEPREMASTER_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
 
 parser = argparse.ArgumentParser(description='Remastering')
 parser.add_argument('--input', type=str, default=None, help='Input video')
@@ -51,7 +51,7 @@ class DeepReasterPredictor:
         self.mindim = mindim
 
         if weight_path is None:
-            weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path)
+            weight_path = get_path_from_url(DEEPREMASTER_WEIGHT_URL, cur_path)
 
         state_dict, _ = paddle.load(weight_path)
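The deleted spectral-norm module is not gone: the same `_SpectralNorm`/`Spectralnorm` pair reappears in ppgan/modules/nn.py later in this patch, so application code now imports it from the package. A minimal sketch of wrapping a layer with it (channel sizes illustrative, using the positional nn.Conv2d signature this codebase targets):

    import paddle.nn as nn
    from ppgan.modules.nn import Spectralnorm

    # spectrally-normalized 3x3 conv, as a discriminator might use it
    conv = Spectralnorm(nn.Conv2d(64, 128, 3, 1, 1))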
diff --git a/applications/EDVR/predict.py b/applications/EDVR/predict.py
index 11ab8928e877b36ef236fd73e15bf0ef381ded39..5f95714cea667555d64ddfea83d75c8191b773b0 100644
--- a/applications/EDVR/predict.py
+++ b/applications/EDVR/predict.py
@@ -30,8 +30,9 @@ import cv2
 from tqdm import tqdm
 from data import EDVRDataset
 from paddle.utils.download import get_path_from_url
+from ppgan.utils.video import frames2video, video2frames
 
-EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
+EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
 
 
 def parse_args():
@@ -71,52 +72,6 @@ def save_img(img, framename):
     cv2.imwrite(framename, img)
 
-def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
-    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
-    vid_name = vid_path.split('/')[-1].split('.')[0]
-    out_full_path = os.path.join(outpath, 'frames_input')
-
-    if not os.path.exists(out_full_path):
-        os.makedirs(out_full_path)
-
-    # video file name
-    outformat = out_full_path + '/%08d.png'
-
-    if ss is not None and t is not None and r is not None:
-        cmd = ffmpeg + [
-            ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
-            ' 0.1 ', ' -start_number ', ' 0 ', outformat
-        ]
-    else:
-        cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
-
-    cmd = ''.join(cmd)
-
-    if os.system(cmd) == 0:
-        pass
-    else:
-        print('ffmpeg process video: {} error'.format(vid_name))
-
-    sys.stdout.flush()
-    return out_full_path
-
-
-def frames_to_video_ffmpeg(framepath, videopath, r):
-    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
-    cmd = ffmpeg + [
-        ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
-        ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
-    ]
-    cmd = ''.join(cmd)
-
-    if os.system(cmd) == 0:
-        pass
-    else:
-        print('ffmpeg process video: {} error'.format(videopath))
-
-    sys.stdout.flush()
-
-
 class EDVRPredictor:
     def __init__(self, input, output, weight_path=None):
         self.input = input
@@ -127,9 +82,7 @@ class EDVRPredictor:
         self.exe = fluid.Executor(place)
 
         if weight_path is None:
-            weight_path = get_path_from_url(EDVR_weight_url, cur_path)
-
-            print(weight_path)
+            weight_path = get_path_from_url(EDVR_WEIGHT_URL, cur_path)
 
         model_filename = 'EDVR_model.pdmodel'
         params_filename = 'EDVR_params.pdparams'
@@ -155,7 +108,7 @@ class EDVRPredictor:
         cap = cv2.VideoCapture(vid)
         fps = cap.get(cv2.CAP_PROP_FPS)
 
-        out_path = dump_frames_ffmpeg(vid, output_path)
+        out_path = video2frames(vid, output_path)
 
         frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
@@ -188,8 +141,7 @@ class EDVRPredictor:
         frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
         vid_out_path = os.path.join(self.output,
                                     '{}_edvr_out.mp4'.format(base_name))
-        frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
-                               str(int(fps)))
+        frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
 
         return frame_pattern_combined, vid_out_path
diff --git a/applications/RealSR/predict.py b/applications/RealSR/predict.py
index a05bf788e0f11ac73119b0f31c49a3615764c878..d032bc2a78029e174d80d0d124036db75a42d0cb 100644
--- a/applications/RealSR/predict.py
+++ b/applications/RealSR/predict.py
@@ -13,7 +13,9 @@ import pickle
 from PIL import Image
 from tqdm import tqdm
 
-from sr_model import RRDBNet
+
+from ppgan.models.generators import RRDBNet
+from ppgan.utils.video import frames2video, video2frames
 from paddle.utils.download import get_path_from_url
 
 parser = argparse.ArgumentParser(description='RealSR')
@@ -24,23 +26,7 @@ parser.add_argument('--weight_path',
                     default=None,
                     help='Path to the reference image directory')
 
-RealSR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams'
-
-
-def frames_to_video_ffmpeg(framepath, videopath, r):
-    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
-    cmd = ffmpeg + [
-        ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
-        ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
-    ]
-    cmd = ''.join(cmd)
-
-    if os.system(cmd) == 0:
-        pass
-    else:
-        print('ffmpeg process video: {} error'.format(videopath))
-
-    sys.stdout.flush()
+REALSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams'
 
 
 class RealSRPredictor():
@@ -49,7 +35,7 @@ class RealSRPredictor():
         self.output = os.path.join(output, 'RealSR')
         self.model = RRDBNet(3, 3, 64, 23)
         if weight_path is None:
-            weight_path = get_path_from_url(RealSR_weight_url, cur_path)
+            weight_path = get_path_from_url(REALSR_WEIGHT_URL, cur_path)
 
         state_dict, _ = paddle.load(weight_path)
         self.model.load_dict(state_dict)
@@ -88,7 +74,7 @@ class RealSRPredictor():
         cap = cv2.VideoCapture(vid)
         fps = cap.get(cv2.CAP_PROP_FPS)
 
-        out_path = dump_frames_ffmpeg(vid, output_path)
+        out_path = video2frames(vid, output_path)
 
         frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
@@ -102,42 +88,11 @@ class RealSRPredictor():
         vid_out_path = os.path.join(output_path,
                                     '{}_realsr_out.mp4'.format(base_name))
-        frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
-                               str(int(fps)))
+        frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
 
         return frame_pattern_combined, vid_out_path
 
-
-def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
-    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
-    vid_name = vid_path.split('/')[-1].split('.')[0]
-    out_full_path = os.path.join(outpath, 'frames_input')
-
-    if not os.path.exists(out_full_path):
-        os.makedirs(out_full_path)
-
-    # video file name
-    outformat = out_full_path + '/%08d.png'
-
-    if ss is not None and t is not None and r is not None:
-        cmd = ffmpeg + [
-            ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
-            ' 0.1 ', ' -start_number ', ' 0 ', outformat
-        ]
-    else:
-        cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
-
-    cmd = ''.join(cmd)
-
-    if os.system(cmd) == 0:
-        pass
-    else:
-        print('ffmpeg process video: {} error'.format(vid_name))
-
-    sys.stdout.flush()
-    return out_full_path
-
-
 if __name__ == '__main__':
     paddle.disable_static()
    args = parser.parse_args()
diff --git a/configs/cyclegan_cityscapes.yaml b/configs/cyclegan_cityscapes.yaml
index c4facd8bc817f5c3ffb450feae0c221880aa5dff..f74d9e3bdd35e3a521b3056a5b268f67bba2e406 100644
--- a/configs/cyclegan_cityscapes.yaml
+++ b/configs/cyclegan_cityscapes.yaml
@@ -41,6 +41,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: False
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
   test:
     name: SingleDataset
     dataroot: data/cityscapes/testB
@@ -55,6 +60,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: True
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
 
 optimizer:
diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml
index 1ea5c6d1687c35197a1470833b33cdefcf0ba5ee..0e845bd5183428f7c166bae300f74757406c07f5 100644
--- a/configs/cyclegan_horse2zebra.yaml
+++ b/configs/cyclegan_horse2zebra.yaml
@@ -40,6 +40,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: False
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
   test:
     name: SingleDataset
     dataroot: data/horse2zebra/testA
@@ -54,7 +59,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: True
-
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
 
 optimizer:
   name: Adam
diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml
index 06577f7f1a50301431832b541b17c78802e35077..5919ff2e5a5c2c267a9204d117dc7aba5fb245a7 100644
--- a/configs/pix2pix_cityscapes.yaml
+++ b/configs/pix2pix_cityscapes.yaml
@@ -38,6 +38,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: False
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
   test:
     name: PairedDataset
     dataroot: data/cityscapes/
@@ -53,6 +58,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: True
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
 
 optimizer:
   name: Adam
diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml
index a64b57a8c5e7bd80c71109bef58b8d8bf17fffff..20f494c6fb13690254dd2d047df8c8970615ebff 100644
--- a/configs/pix2pix_cityscapes_2gpus.yaml
+++ b/configs/pix2pix_cityscapes_2gpus.yaml
@@ -37,6 +37,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: False
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
   test:
     name: PairedDataset
     dataroot: data/cityscapes/
@@ -52,6 +57,11 @@ dataset:
      crop_size: 256
       preprocess: resize_and_crop
       no_flip: True
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
 
 optimizer:
   name: Adam
diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml
index ede78386fdd09d6f67d797d658c494892c316fd9..31b5f145dccdfd75bbdcd14c3fa896676d729037 100644
--- a/configs/pix2pix_facades.yaml
+++ b/configs/pix2pix_facades.yaml
@@ -37,6 +37,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: False
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
   test:
     name: PairedDataset
     dataroot: data/facades/
@@ -52,6 +57,11 @@ dataset:
       crop_size: 256
       preprocess: resize_and_crop
       no_flip: True
+      normalize:
+        mean:
+            (127.5, 127.5, 127.5)
+        std:
+            (127.5, 127.5, 127.5)
 
 optimizer:
   name: Adam
diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py
index 9b807e9be0c83dda6415ebf01418cc77b8f463ba..0aeb70936b58125fb92d00ce5905e2608142f728 100644
--- a/ppgan/datasets/__init__.py
+++ b/ppgan/datasets/__init__.py
@@ -1,3 +1,4 @@
 from .unpaired_dataset import UnpairedDataset
 from .single_dataset import SingleDataset
 from .paired_dataset import PairedDataset
+from .sr_image_dataset import SRImageDataset
\ No newline at end of file
diff --git a/ppgan/datasets/base_dataset.py b/ppgan/datasets/base_dataset.py
index 87e996925477c5fa096df48779327397ce22873b..fe93c71e718faaff1c019db9fb6632509d3db4f1 100644
--- a/ppgan/datasets/base_dataset.py
+++ b/ppgan/datasets/base_dataset.py
@@ -94,7 +94,9 @@ def get_transform(cfg,
     if convert:
         transform_list += [transforms.Permute(to_rgb=True)]
-        transform_list += [
-            transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
-        ]
+        if cfg.get('normalize', None):
+            transform_list += [
+                transforms.Normalize(cfg.normalize.mean, cfg.normalize.std)
+            ]
+
     return transforms.Compose(transform_list)
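get_transform now appends transforms.Normalize only when the dataset config carries a `normalize` block, which is why every train/test section above gains one; a config without it (e.g. a [0, 1] SR dataset) skips normalization entirely. A dict-style sketch of the new schema (the real configs are YAML; values mirror the patch):

    transform_cfg = {
        'load_size': 286,
        'crop_size': 256,
        'preprocess': 'resize_and_crop',
        'no_flip': False,
        # optional: omit this key and no Normalize transform is added
        'normalize': {
            'mean': (127.5, 127.5, 127.5),
            'std': (127.5, 127.5, 127.5),
        },
    }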
diff --git a/ppgan/datasets/sr_image_dataset.py b/ppgan/datasets/sr_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..977d2657ce10b1b022bb3157d067616425fd7ef3
--- /dev/null
+++ b/ppgan/datasets/sr_image_dataset.py
@@ -0,0 +1,243 @@
+# import mmcv
+import os
+import cv2
+import random
+import numpy as np
+import paddle.vision.transforms as transform
+
+from pathlib import Path
+from paddle.io import Dataset
+from .builder import DATASETS
+
+
+def scandir(dir_path, suffix=None, recursive=False):
+    """Scan a directory to find the interested files.
+    """
+    if isinstance(dir_path, (str, Path)):
+        dir_path = str(dir_path)
+    else:
+        raise TypeError('"dir_path" must be a string or Path object')
+
+    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+        raise TypeError('"suffix" must be a string or tuple of strings')
+
+    root = dir_path
+
+    def _scandir(dir_path, suffix, recursive):
+        for entry in os.scandir(dir_path):
+            if not entry.name.startswith('.') and entry.is_file():
+                rel_path = os.path.relpath(entry.path, root)
+                if suffix is None:
+                    yield rel_path
+                elif rel_path.endswith(suffix):
+                    yield rel_path
+            else:
+                if recursive:
+                    yield from _scandir(
+                        entry.path, suffix=suffix, recursive=recursive)
+                else:
+                    continue
+
+    return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+    """Generate paired paths from folders.
+    """
+    assert len(folders) == 2, (
+        'The len of folders should be 2 with [input_folder, gt_folder]. '
+        f'But got {len(folders)}')
+    assert len(keys) == 2, (
+        'The len of keys should be 2 with [input_key, gt_key]. '
+        f'But got {len(keys)}')
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    input_paths = list(scandir(input_folder))
+    gt_paths = list(scandir(gt_folder))
+    assert len(input_paths) == len(gt_paths), (
+        f'{input_key} and {gt_key} datasets have different number of images: '
+        f'{len(input_paths)}, {len(gt_paths)}.')
+    paths = []
+    for gt_path in gt_paths:
+        basename, ext = os.path.splitext(os.path.basename(gt_path))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = os.path.join(input_folder, input_name)
+        assert input_name in input_paths, (f'{input_name} is not in '
+                                           f'{input_key}_paths.')
+        gt_path = os.path.join(gt_folder, gt_path)
+        paths.append(
+            dict([(f'{input_key}_path', input_path),
+                  (f'{gt_key}_path', gt_path)]))
+    return paths
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
+    """Paired random crop.
+
+    It crops lists of lq and gt images with corresponding locations.
+
+    Args:
+        img_gts (list[ndarray] | ndarray): GT images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        gt_patch_size (int): GT patch size.
+        scale (int): Scale factor.
+        gt_path (str): Path to ground-truth.
+
+    Returns:
+        list[ndarray] | ndarray: GT images and LQ images. If returned results
+            only have one element, just return ndarray.
+    """
+
+    if not isinstance(img_gts, list):
+        img_gts = [img_gts]
+    if not isinstance(img_lqs, list):
+        img_lqs = [img_lqs]
+
+    h_lq, w_lq, _ = img_lqs[0].shape
+    h_gt, w_gt, _ = img_gts[0].shape
+    lq_patch_size = gt_patch_size // scale
+
+    if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(
+            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
+            f'multiplication of LQ ({h_lq}, {w_lq}).')
+    if h_lq < lq_patch_size or w_lq < lq_patch_size:
+        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+                         f'({lq_patch_size}, {lq_patch_size}). '
+                         f'Please remove {gt_path}.')
+
+    # randomly choose top and left coordinates for lq patch
+    top = random.randint(0, h_lq - lq_patch_size)
+    left = random.randint(0, w_lq - lq_patch_size)
+
+    # crop lq patch
+    img_lqs = [
+        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+        for v in img_lqs
+    ]
+
+    # crop corresponding gt patch
+    top_gt, left_gt = int(top * scale), int(left * scale)
+    img_gts = [
+        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
+        for v in img_gts
+    ]
+    if len(img_gts) == 1:
+        img_gts = img_gts[0]
+    if len(img_lqs) == 1:
+        img_lqs = img_lqs[0]
+    return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None):
+    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+    """
+    hflip = hflip and random.random() < 0.5
+    vflip = rotation and random.random() < 0.5
+    rot90 = rotation and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:
+            cv2.flip(img, 1, img)
+        if vflip:
+            cv2.flip(img, 0, img)
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    def _augment_flow(flow):
+        if hflip:
+            cv2.flip(flow, 1, flow)
+            flow[:, :, 0] *= -1
+        if vflip:
+            cv2.flip(flow, 0, flow)
+            flow[:, :, 1] *= -1
+        if rot90:
+            flow = flow.transpose(1, 0, 2)
+            flow = flow[:, :, [1, 0]]
+        return flow
+
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    imgs = [_augment(img) for img in imgs]
+    if len(imgs) == 1:
+        imgs = imgs[0]
+
+    if flows is not None:
+        if not isinstance(flows, list):
+            flows = [flows]
+        flows = [_augment_flow(flow) for flow in flows]
+        if len(flows) == 1:
+            flows = flows[0]
+        return imgs, flows
+    else:
+        return imgs
+
+
+@DATASETS.register()
+class SRImageDataset(Dataset):
+    """Paired image dataset for image restoration."""
+
+    def __init__(self, cfg):
+        super(SRImageDataset, self).__init__()
+        self.cfg = cfg
+
+        self.file_client = None
+        self.io_backend_opt = cfg['io_backend']
+
+        self.gt_folder, self.lq_folder = cfg['dataroot_gt'], cfg['dataroot_lq']
+        if 'filename_tmpl' in cfg:
+            self.filename_tmpl = cfg['filename_tmpl']
+        else:
+            self.filename_tmpl = '{}'
+
+        if self.io_backend_opt['type'] == 'lmdb':
+            #TODO: LielinJiang support lmdb to accelerate io
+            pass
+        elif 'meta_info_file' in self.cfg and self.cfg[
+                'meta_info_file'] is not None:
+            #TODO: LielinJiang support lmdb to accelerate io
+            pass
+        else:
+            self.paths = paired_paths_from_folder(
+                [self.lq_folder, self.gt_folder], ['lq', 'gt'],
+                self.filename_tmpl)
+
+    def __getitem__(self, index):
+        scale = self.cfg['scale']
+
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        gt_path = self.paths[index]['gt_path']
+        lq_path = self.paths[index]['lq_path']
+
+        img_gt = cv2.imread(gt_path).astype(np.float32) / 255.
+        img_lq = cv2.imread(lq_path).astype(np.float32) / 255.
+
+        # augmentation for training
+        if self.cfg['phase'] == 'train':
+            gt_size = self.cfg['gt_size']
+            # random crop
+            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
+                                                gt_path)
+            # flip, rotation
+            img_gt, img_lq = augment([img_gt, img_lq], self.cfg['use_flip'],
+                                     self.cfg['use_rot'])
+
+        # TODO: color space transform
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        permute = transform.Permute()
+        img_gt = permute(img_gt)
+        img_lq = permute(img_lq)
+        return {'lq': img_lq,
+                'gt': img_gt,
+                'lq_path': lq_path,
+                'gt_path': gt_path
+                }
+
+    def __len__(self):
+        return len(self.paths)
diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py
index 2e3b21b1a3456b3b7cdbcd76cb5248b4ad4bd763..650aab765e4841e0a4e62d8bb6f6404a37b69d66 100644
--- a/ppgan/engine/trainer.py
+++ b/ppgan/engine/trainer.py
@@ -1,5 +1,6 @@
 import os
 import time
+import copy
 import logging
 import paddle
@@ -10,7 +11,7 @@ from ..datasets.builder import build_dataloader
 from ..models.builder import build_model
 from ..utils.visual import tensor2img, save_image
 from ..utils.filesystem import save, load, makedirs
-
+from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
 
 class Trainer:
     def __init__(self, cfg):
@@ -39,12 +40,17 @@ class Trainer:
         self.weight_interval = cfg.snapshot_config.interval
         self.log_interval = cfg.log_config.interval
         self.visual_interval = cfg.log_config.visiual_interval
+        self.validate_interval = -1
+        if cfg.get('validate', None) is not None:
+            self.validate_interval = cfg.validate.get('interval', -1)
         self.cfg = cfg
 
         self.local_rank = ParallelEnv().local_rank
 
         # time count
         self.time_count = {}
+        self.best_metric = {}
+
 
     def distributed_data_parallel(self):
         strategy = paddle.distributed.prepare_context()
@@ -78,11 +84,58 @@ class Trainer:
                 step_start_time = time.time()
 
             self.logger.info('train one epoch time: {}'.format(time.time() -
                                                                start_time))
+            if self.validate_interval > -1 and epoch % self.validate_interval == 0:
+                self.validate()
             self.model.lr_scheduler.step()
             if epoch % self.weight_interval == 0:
                 self.save(epoch, 'weight', keep=-1)
             self.save(epoch)
 
+    def validate(self):
+        if not hasattr(self, 'val_dataloader'):
+            self.val_dataloader = build_dataloader(self.cfg.dataset.val, is_train=False)
+
+        metric_result = {}
+
+        for i, data in enumerate(self.val_dataloader):
+            self.batch_id = i
+
+            self.model.set_input(data)
+            self.model.test()
+
+            visual_results = {}
+            current_paths = self.model.get_image_paths()
+            current_visuals = self.model.get_current_visuals()
+
+            for j in range(len(current_paths)):
+                short_path = os.path.basename(current_paths[j])
+                basename = os.path.splitext(short_path)[0]
+                for k, img_tensor in current_visuals.items():
+                    name = '%s_%s' % (basename, k)
+                    visual_results.update({name: img_tensor[j]})
+                if 'psnr' in self.cfg.validate.metrics:
+                    if 'psnr' not in metric_result:
+                        metric_result['psnr'] = calculate_psnr(tensor2img(current_visuals['output'][j], (0., 1.)), tensor2img(current_visuals['gt'][j], (0., 1.)), **self.cfg.validate.metrics.psnr)
+                    else:
+                        metric_result['psnr'] += calculate_psnr(tensor2img(current_visuals['output'][j], (0., 1.)), tensor2img(current_visuals['gt'][j], (0., 1.)), **self.cfg.validate.metrics.psnr)
+                if 'ssim' in self.cfg.validate.metrics:
+                    if 'ssim' not in metric_result:
+                        metric_result['ssim'] = calculate_ssim(tensor2img(current_visuals['output'][j], (0., 1.)), tensor2img(current_visuals['gt'][j], (0., 1.)), **self.cfg.validate.metrics.ssim)
+                    else:
+                        metric_result['ssim'] += calculate_ssim(tensor2img(current_visuals['output'][j], (0., 1.)), tensor2img(current_visuals['gt'][j], (0., 1.)), **self.cfg.validate.metrics.ssim)
+
+            self.visual('visual_val', visual_results=visual_results)
+
+            if i % self.log_interval == 0:
+                self.logger.info('val iter: [%d/%d]' %
+                                 (i, len(self.val_dataloader)))
+
+        for metric_name in metric_result.keys():
+            metric_result[metric_name] /= len(self.val_dataloader.dataset)
+
+        self.logger.info('Epoch {} validate end: {}'.format(self.current_epoch, metric_result))
+
+
     def test(self):
         if not hasattr(self, 'test_dataloader'):
             self.test_dataloader = build_dataloader(self.cfg.dataset.test,
@@ -147,8 +200,11 @@ class Trainer:
             msg = ''
         makedirs(os.path.join(self.output_dir, results_dir))
+        min_max = self.cfg.get('min_max', None)
+        if min_max is None:
+            min_max = (-1., 1.)
         for label, image in visual_results.items():
-            image_numpy = tensor2img(image)
+            image_numpy = tensor2img(image, min_max)
             img_path = os.path.join(self.output_dir, results_dir,
                                     msg + '%s.png' % (label))
             save_image(image_numpy, img_path)
@@ -210,5 +266,6 @@ class Trainer:
 
         for name in self.model.model_names:
             if isinstance(name, str):
+                self.logger.info('load model {} {} params!'.format(self.cfg.model.name, 'net' + name))
                 net = getattr(self.model, 'net' + name)
                 net.set_dict(state_dicts['net' + name])
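The validation loop above accumulates each metric over the whole val set, then divides by the dataset length; per sample it converts the model's [0, 1] output tensors to uint8 HWC images before scoring them with the metric functions added below. The same computation standalone (current_visuals and the crop_border value are illustrative; in the trainer the keyword arguments come from the `validate.metrics` config section):

    from ppgan.metric.psnr_ssim import calculate_psnr, calculate_ssim
    from ppgan.utils.visual import tensor2img

    output = tensor2img(current_visuals['output'][0], (0., 1.))
    gt = tensor2img(current_visuals['gt'][0], (0., 1.))
    psnr = calculate_psnr(output, gt, crop_border=4, test_y_channel=True)
    ssim = calculate_ssim(output, gt, crop_border=4, test_y_channel=True)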
diff --git a/ppgan/metric/metric_util.py b/ppgan/metric/metric_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d81c2f64086b219f045a58d65463b0099a793aad
--- /dev/null
+++ b/ppgan/metric/metric_util.py
@@ -0,0 +1,78 @@
+import numpy as np
+
+
+def reorder_image(img, input_order='HWC'):
+    """Reorder images to 'HWC' order.
+
+    If the input_order is (h, w), return (h, w, 1);
+    If the input_order is (c, h, w), return (h, w, c);
+    If the input_order is (h, w, c), return as it is.
+
+    Args:
+        img (ndarray): Input image.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            If the input image shape is (h, w), input_order will not have
+            effects. Default: 'HWC'.
+
+    Returns:
+        ndarray: reordered image.
+    """
+
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(
+            f'Wrong input_order {input_order}. Supported input_orders are '
+            "'HWC' and 'CHW'")
+    if len(img.shape) == 2:
+        img = img[..., None]
+        return img
+    if input_order == 'CHW':
+        img = img.transpose(1, 2, 0)
+    return img
+
+def bgr2ycbcr(img, y_only=False):
+    """Convert a BGR image to YCbCr image.
+
+    The bgr version of rgb2ycbcr.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+        and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                  [65.481, -37.797, 112.0]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+def to_y_channel(img):
+    """Change to Y channel of YCbCr.
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+
+    Returns:
+        (ndarray): Images with range [0, 255] (float type) without round.
+    """
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 3 and img.shape[2] == 3:
+        img = bgr2ycbcr(img, y_only=True)
+        img = img[..., None]
+    return img * 255.
diff --git a/ppgan/metric/psnr_ssim.py b/ppgan/metric/psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..65cfde310cc1e2b8149443cdd9c85330e278509e
--- /dev/null
+++ b/ppgan/metric/psnr_ssim.py
@@ -0,0 +1,137 @@
+import cv2
+import numpy as np
+
+from .metric_util import reorder_image, to_y_channel
+
+
+def calculate_psnr(img1,
+                   img2,
+                   crop_border,
+                   input_order='HWC',
+                   test_y_channel=False):
+    """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the PSNR calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: psnr result.
+    """
+
+    assert img1.shape == img2.shape, (
+        f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(
+            f'Wrong input_order {input_order}. Supported input_orders are '
+            '"HWC" and "CHW"')
+    img1 = reorder_image(img1, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+
+    if crop_border != 0:
+        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+    """Calculate SSIM (structural similarity) for one channel images.
+
+    It is called by func:`calculate_ssim`.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+        img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+    Returns:
+        float: ssim result.
+    """
+
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) *
+                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                       (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
+
+
+def calculate_ssim(img1,
+                   img2,
+                   crop_border,
+                   input_order='HWC',
+                   test_y_channel=False):
+    """Calculate SSIM (structural similarity).
+
+    Ref:
+    Image quality assessment: From error visibility to structural similarity
+
+    The results are the same as that of the official released MATLAB code in
+    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+    For three-channel images, SSIM is calculated for each channel and then
+    averaged.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the SSIM calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: ssim result.
+    """
+
+    assert img1.shape == img2.shape, (
+        f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(
+            f'Wrong input_order {input_order}. Supported input_orders are '
+            '"HWC" and "CHW"')
+    img1 = reorder_image(img1, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+
+    if crop_border != 0:
+        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    ssims = []
+    for i in range(img1.shape[2]):
+        ssims.append(_ssim(img1[..., i], img2[..., i]))
+    return np.array(ssims).mean()
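Note that `bgr2ycbcr` in metric_util.py calls `_convert_input_type_range` and `_convert_output_type_range`, which the patch never defines, so `to_y_channel` would raise NameError as written. A sketch of those helpers as they appear in the reference implementation this file is adapted from (an assumed dependency, not part of the patch itself):

    import numpy as np

    def _convert_input_type_range(img):
        # normalize uint8 [0, 255] or float32 [0, 1] input to float32 [0, 1]
        img_type = img.dtype
        img = img.astype(np.float32)
        if img_type == np.float32:
            pass
        elif img_type == np.uint8:
            img /= 255.
        else:
            raise TypeError('The img type should be np.float32 or np.uint8, '
                            f'but got {img_type}')
        return img

    def _convert_output_type_range(img, dst_type):
        # map a float32 [0, 255] result back to the caller's dtype and range
        if dst_type not in (np.uint8, np.float32):
            raise TypeError('The dst_type should be np.float32 or np.uint8, '
                            f'but got {dst_type}')
        if dst_type == np.uint8:
            img = img.round()
        else:
            img /= 255.
        return img.astype(dst_type)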
diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py
index 621e8edf337837440e679d9969ffd614520268e2..1fb4e96098b6ea230c029c8c0f0ff7ad2eb5b139 100644
--- a/ppgan/models/__init__.py
+++ b/ppgan/models/__init__.py
@@ -1,4 +1,6 @@
 from .base_model import BaseModel
 from .cycle_gan_model import CycleGANModel
 from .pix2pix_model import Pix2PixModel
+from .srgan_model import SRGANModel
+from .sr_model import SRModel
diff --git a/ppgan/models/backbones/__init__.py b/ppgan/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c876f93de18a3a11b0946e65dc1f610d73bf7ec9
--- /dev/null
+++ b/ppgan/models/backbones/__init__.py
@@ -0,0 +1 @@
+from .resnet_backbone import resnet18, resnet34, resnet50, resnet101, resnet152
\ No newline at end of file
diff --git a/applications/DeOldify/resnet_backbone.py b/ppgan/models/backbones/resnet_backbone.py
similarity index 100%
rename from applications/DeOldify/resnet_backbone.py
rename to ppgan/models/backbones/resnet_backbone.py
diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py
index 840d716a1399053dbb621e4d4ede60aec4afde0b..15ac59d156f852e10fa4263fb4fd5b1fe9f7a976 100644
--- a/ppgan/models/generators/__init__.py
+++ b/ppgan/models/generators/__init__.py
@@ -1,2 +1,3 @@
 from .resnet import ResnetGenerator
-from .unet import UnetGenerator
\ No newline at end of file
+from .unet import UnetGenerator
+from .rrdb_net import RRDBNet
\ No newline at end of file
diff --git a/applications/DeOldify/model.py b/ppgan/models/generators/deoldify.py
similarity index 96%
rename from applications/DeOldify/model.py
rename to ppgan/models/generators/deoldify.py
index 9f97ed8667a70c248c8d6d075e4b4c7f05f186d0..b7f875364dee3edfedf98c4a9bf89c0a50dd5ad9 100644
--- a/applications/DeOldify/model.py
+++ b/ppgan/models/generators/deoldify.py
@@ -3,10 +3,9 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-from resnet_backbone import resnet34, resnet101
-from hook import hook_outputs, model_sizes, dummy_eval
-from spectral_norm import Spectralnorm
-from paddle import fluid
+from .hook import hook_outputs, model_sizes, dummy_eval
+from ..backbones import resnet34, resnet101
+from ...modules.nn import Spectralnorm
 
 class SequentialEx(nn.Layer):
@@ -206,7 +205,7 @@ class UnetBlockWide(nn.Layer):
         return self.conv(cat_x)
 
-class UnetBlockDeep(paddle.fluid.Layer):
+class UnetBlockDeep(nn.Layer):
     "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
 
     def __init__(
@@ -319,7 +318,7 @@ def conv_layer(ni: int,
     return nn.Sequential(*layers)
 
-class CustomPixelShuffle_ICNR(paddle.fluid.Layer):
+class CustomPixelShuffle_ICNR(nn.Layer):
     "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
 
     def __init__(self,
@@ -349,7 +348,7 @@ class CustomPixelShuffle_ICNR(paddle.fluid.Layer):
         return self.blur(self.pad(x)) if self.blur else x
 
-class MergeLayer(paddle.fluid.Layer):
+class MergeLayer(nn.Layer):
     "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
 
     def __init__(self, dense: bool = False):
@@ -379,7 +378,7 @@ def res_block(nf,
                          MergeLayer(dense))
 
-class SigmoidRange(paddle.fluid.Layer):
+class SigmoidRange(nn.Layer):
     "Sigmoid module with range `(low,x_max)`"
 
     def __init__(self, low, high):
@@ -395,13 +394,13 @@ def sigmoid_range(x, low, high):
     return F.sigmoid(x) * (high - low) + low
 
-class PixelShuffle(paddle.fluid.Layer):
+class PixelShuffle(nn.Layer):
     def __init__(self, upscale_factor):
         super(PixelShuffle, self).__init__()
         self.upscale_factor = upscale_factor
 
     def forward(self, x):
-        return paddle.fluid.layers.pixel_shuffle(x, self.upscale_factor)
+        return F.pixel_shuffle(x, self.upscale_factor)
 
 
 class ReplicationPad2d(nn.Layer):
@@ -410,7 +409,7 @@ class ReplicationPad2d(nn.Layer):
         self.size = size
 
     def forward(self, x):
-        return paddle.fluid.layers.pad2d(x, self.size, mode="edge")
+        return F.pad2d(x, self.size, mode="edge")
 
 
 def conv1d(ni: int,
diff --git a/applications/DeOldify/hook.py b/ppgan/models/generators/hook.py
similarity index 100%
rename from applications/DeOldify/hook.py
rename to ppgan/models/generators/hook.py
diff --git a/applications/DeepRemaster/remasternet.py b/ppgan/models/generators/remaster.py
similarity index 100%
rename from applications/DeepRemaster/remasternet.py
rename to ppgan/models/generators/remaster.py
diff --git a/applications/RealSR/sr_model.py b/ppgan/models/generators/rrdb_net.py
similarity index 91%
rename from applications/RealSR/sr_model.py
rename to ppgan/models/generators/rrdb_net.py
index c8a730bea00c3b512473785290bc27a0744cf7e0..008da739e38e170b1e02f629dba1576ce09a9723 100644
--- a/applications/RealSR/sr_model.py
+++ b/ppgan/models/generators/rrdb_net.py
@@ -3,6 +3,8 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
+from .builder import GENERATORS
+
 
 class ResidualDenseBlock_5C(nn.Layer):
     def __init__(self, nf=64, gc=32, bias=True):
@@ -15,6 +17,7 @@ class ResidualDenseBlock_5C(nn.Layer):
         self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias)
         self.lrelu = nn.LeakyReLU(negative_slope=0.2)
 
+
     def forward(self, x):
         x1 = self.lrelu(self.conv1(x))
         x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1)))
@@ -26,6 +29,7 @@ class ResidualDenseBlock_5C(nn.Layer):
 
 class RRDB(nn.Layer):
     '''Residual in Residual Dense Block'''
+
     def __init__(self, nf, gc=32):
         super(RRDB, self).__init__()
         self.RDB1 = ResidualDenseBlock_5C(nf, gc)
@@ -38,7 +42,6 @@ class RRDB(nn.Layer):
         out = self.RDB3(out)
         return out * 0.2 + x
 
-
 def make_layer(block, n_layers):
     layers = []
     for _ in range(n_layers):
@@ -46,6 +49,7 @@ def make_layer(block, n_layers):
     return nn.Sequential(*layers)
 
 
+@GENERATORS.register()
 class RRDBNet(nn.Layer):
     def __init__(self, in_nc, out_nc, nf, nb, gc=32):
         super(RRDBNet, self).__init__()
@@ -67,10 +71,8 @@ class RRDBNet(nn.Layer):
         trunk = self.trunk_conv(self.RRDB_trunk(fea))
         fea = fea + trunk
 
-        fea = self.lrelu(
-            self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
-        fea = self.lrelu(
-            self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
         out = self.conv_last(self.lrelu(self.HRconv(fea)))
 
         return out
diff --git a/ppgan/models/sr_model.py b/ppgan/models/sr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f912431d3f8b92e39702a3ac960db8f9a3b166
--- /dev/null
+++ b/ppgan/models/sr_model.py
@@ -0,0 +1,72 @@
+from collections import OrderedDict
+import paddle
+import paddle.nn as nn
+
+from .generators.builder import build_generator
+from .discriminators.builder import build_discriminator
+from ..solver import build_optimizer
+from .base_model import BaseModel
+from .losses import GANLoss
+from .builder import MODELS
+
+import importlib
+from collections import OrderedDict
+from copy import deepcopy
+from os import path as osp
+from .builder import MODELS
+
+
+@MODELS.register()
+class SRModel(BaseModel):
+    """Base SR model for single image super-resolution."""
+    def __init__(self, cfg):
+        super(SRModel, self).__init__(cfg)
+
+        self.model_names = ['G']
+
+        self.netG = build_generator(cfg.model.generator)
+        self.visual_names = ['lq', 'output', 'gt']
+
+        self.loss_names = ['l_total']
+
+        self.optimizers = []
+        if self.isTrain:
+            self.criterionL1 = paddle.nn.L1Loss()
+
+            self.build_lr_scheduler()
+            self.optimizer_G = build_optimizer(
+                cfg.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netG.parameters())
+            self.optimizers.append(self.optimizer_G)
+
+    def set_input(self, input):
+        self.lq = paddle.to_tensor(input['lq'])
+        if 'gt' in input:
+            self.gt = paddle.to_tensor(input['gt'])
+        self.image_paths = input['lq_path']
+
+    def forward(self):
+        pass
+
+    def test(self):
+        """Forward function used in test time.
+        """
+        with paddle.no_grad():
+            self.output = self.netG(self.lq)
+
+    def optimize_parameters(self):
+        self.optimizer_G.clear_grad()
+        self.output = self.netG(self.lq)
+
+        l_total = 0
+        loss_dict = OrderedDict()
+        # pixel loss
+        if self.criterionL1:
+            l_pix = self.criterionL1(self.output, self.gt)
+            l_total += l_pix
+            loss_dict['l_pix'] = l_pix
+
+        l_total.backward()
+        self.loss_l_total = l_total
+        self.optimizer_G.step()
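SRModel.optimize_parameters boils down to a single L1 pixel-loss update on the generator. The same step written standalone, in the order the method uses it (netG, lq, gt, and optimizer_G are assumed to be built as in the class):

    import paddle

    optimizer_G.clear_grad()               # reset gradients first
    output = netG(lq)                      # forward the low-quality batch
    l_pix = paddle.nn.L1Loss()(output, gt) # pixel loss against ground truth
    l_pix.backward()
    optimizer_G.step()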
+ """ + + # AtoB = self.opt.dataset.train.direction == 'AtoB' + if 'A' in input: + self.LQ = paddle.to_tensor(input['A']) + if 'B' in input: + self.GT = paddle.to_tensor(input['B']) + if 'A_paths' in input: + self.image_paths = input['A_paths'] + + def forward(self): + self.fake_H = self.netG(self.LQ) + + def optimize_parameters(self, step): + pass diff --git a/ppgan/modules/nn.py b/ppgan/modules/nn.py index 9620877a449656c687cd29e3410ec571ea224f3c..f867b7284efc5db50e9b7f22dfd6d6c31073e682 100644 --- a/ppgan/modules/nn.py +++ b/ppgan/modules/nn.py @@ -69,21 +69,59 @@ class BCEWithLogitsLoss(): return out -# class BCEWithLogitsLoss(fluid.dygraph.Layer): -# def __init__(self, weight=None, reduction='mean'): -# if reduction not in ['sum', 'mean', 'none']: -# raise ValueError( -# "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but " -# "received %s, which is not allowed." % reduction) - -# super(BCEWithLogitsLoss, self).__init__() -# # self.weight = weight -# # self.reduction = reduction -# self.bce_loss = paddle.nn.BCELoss(weight, reduction) - -# def forward(self, input, label): -# input = paddle.nn.functional.sigmoid(input, True) -# return self.bce_loss(input, label) +class _SpectralNorm(paddle.nn.SpectralNorm): + def __init__(self, + weight_shape, + dim=0, + power_iters=1, + eps=1e-12, + dtype='float32'): + super(_SpectralNorm, self).__init__(weight_shape, dim, power_iters, eps, dtype) + + def forward(self, weight): + paddle.fluid.data_feeder.check_variable_and_dtype(weight, "weight", ['float32', 'float64'], + 'SpectralNorm') + inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} + out = self._helper.create_variable_for_type_inference(self._dtype) + _power_iters = self._power_iters if self.training else 0 + self._helper.append_op( + type="spectral_norm", + inputs=inputs, + outputs={"Out": out, }, + attrs={ + "dim": self._dim, + "power_iters": _power_iters, + "eps": self._eps, + }) + + return out + + +class Spectralnorm(paddle.nn.Layer): + + def __init__(self, + layer, + dim=0, + power_iters=1, + eps=1e-12, + dtype='float32'): + super(Spectralnorm, self).__init__() + self.spectral_norm = _SpectralNorm(layer.weight.shape, dim, power_iters, eps, dtype) + self.dim = dim + self.power_iters = power_iters + self.eps = eps + self.layer = layer + weight = layer._parameters['weight'] + del layer._parameters['weight'] + self.weight_orig = self.create_parameter(weight.shape, dtype=weight.dtype) + self.weight_orig.set_value(weight) + + + def forward(self, x): + weight = self.spectral_norm(self.weight_orig) + self.layer.weight = weight + out = self.layer(x) + return out def initial_type( diff --git a/ppgan/utils/video.py b/ppgan/utils/video.py new file mode 100644 index 0000000000000000000000000000000000000000..056e547e13bfc187b679293d0ec9bc406100e6ef --- /dev/null +++ b/ppgan/utils/video.py @@ -0,0 +1,44 @@ +import os +import sys + +def video2frames(video_path, outpath, **kargs): + def _dict2str(kargs): + cmd_str = '' + for k, v in kargs.items(): + cmd_str += (' ' + str(k) + ' ' + str(v)) + return cmd_str + + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] + vid_name = video_path.split('/')[-1].split('.')[0] + out_full_path = os.path.join(outpath, vid_name) + + if not os.path.exists(out_full_path): + os.makedirs(out_full_path) + + # video file name + outformat = out_full_path + '/%08d.png' + + cmd = ffmpeg + cmd = ffmpeg + [' -i ', video_path, ' -start_number ', ' 0 ', outformat] + + cmd = ''.join(cmd) + _dict2str(kargs) + + if os.system(cmd) != 0: + raise 
RuntimeError('ffmpeg process video: {} error'.format(vid_name)) + + sys.stdout.flush() + return out_full_path + + +def frames2video(frame_path, video_path, r): + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] + cmd = ffmpeg + [ + ' -r ', r, ' -f ', ' image2 ', ' -i ', frame_path, ' -vcodec ', + ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', video_path + ] + cmd = ''.join(cmd) + + if os.system(cmd) != 0: + raise RuntimeError('ffmpeg process video: {} error'.format(video_path)) + + sys.stdout.flush() \ No newline at end of file diff --git a/ppgan/utils/visual.py b/ppgan/utils/visual.py index a50c59eb673a3a792e0712c94f898dee94e12e5d..f6c46e2793ce19fca3aff5ef8535f4d89e851cd2 100644 --- a/ppgan/utils/visual.py +++ b/ppgan/utils/visual.py @@ -2,7 +2,7 @@ import numpy as np from PIL import Image -def tensor2img(input_image, imtype=np.uint8): +def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8): """"Converts a Tensor array into a numpy image array. Parameters: @@ -15,7 +15,9 @@ def tensor2img(input_image, imtype=np.uint8): image_numpy = image_numpy[0] if image_numpy.shape[0] == 1: # grayscale to RGB image_numpy = np.tile(image_numpy, (3, 1, 1)) - image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling + image_numpy = image_numpy.clip(min_max[0], min_max[1]) + image_numpy = (image_numpy - min_max[0]) / (min_max[1] - min_max[0]) + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling else: # if it is a numpy array, do nothing image_numpy = input_image return image_numpy.astype(imtype) diff --git a/requirments.txt b/requirments.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa9cf06427ae88cdded67bd3d5853f32a1f5e34f --- /dev/null +++ b/requirments.txt @@ -0,0 +1 @@ +tqdm \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..3d5adfc0d90f98274934ffe58c355ce8bb9f082d --- /dev/null +++ b/setup.py @@ -0,0 +1,49 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
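tensor2img's new `min_max` path clips to the given range, shifts to [0, 1], then scales by 255, so both [-1, 1] GAN outputs and [0, 1] SR outputs map to valid uint8 images. The same arithmetic in plain numpy:

    import numpy as np

    x = np.array([-1.5, 0.0, 1.0])                 # toy network output
    lo, hi = -1., 1.                               # default min_max
    y = (x.clip(lo, hi) - lo) / (hi - lo) * 255.0  # -> [0., 127.5, 255.]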
diff --git a/requirments.txt b/requirments.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fa9cf06427ae88cdded67bd3d5853f32a1f5e34f
--- /dev/null
+++ b/requirments.txt
@@ -0,0 +1 @@
+tqdm
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d5adfc0d90f98274934ffe58c355ce8bb9f082d
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from setuptools import setup
+from io import open
+
+with open('requirments.txt', encoding="utf-8-sig") as f:
+    requirements = f.readlines()
+
+
+def readme():
+    with open('doc/doc_en/whl_en.md', encoding="utf-8-sig") as f:
+        README = f.read()
+    return README
+
+
+setup(
+    name='ppgan',
+    packages=['ppgan'],
+    include_package_data=True,
+    entry_points={"console_scripts": ["paddlegan= paddlegan.paddlegan:main"]},
+    version='0.1.0',
+    install_requires=requirements,
+    license='Apache License 2.0',
+    description='Awesome GAN toolkits based on PaddlePaddle',
+    url='https://github.com/PaddlePaddle/PaddleGAN',
+    download_url='https://github.com/PaddlePaddle/PaddleGAN.git',
+    keywords=[
+        'gan paddlegan'
+    ],
+    classifiers=[
+        'Intended Audience :: Developers', 'Operating System :: OS Independent',
+        'Natural Language :: Chinese (Simplified)',
+        'Programming Language :: Python :: 3',
+        'Programming Language :: Python :: 3.5',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
+    ], )
\ No newline at end of file
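With setup.py in place, the consolidated modules are meant to be imported from the installed package rather than from per-application copies; note that requirments.txt pins only tqdm, so paddlepaddle itself must be installed separately. A sketch, assuming an editable install:

    # pip install -e .   (run from the repository root)
    from ppgan.models.generators import RRDBNet
    from ppgan.utils.video import video2frames, frames2video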