diff --git a/configs/basicvsr++_vimeo90k_BD.yaml b/configs/basicvsr++_vimeo90k_BD.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3b1c8691ef0e844cd730dd70a8d05f7ef57e1f14
--- /dev/null
+++ b/configs/basicvsr++_vimeo90k_BD.yaml
@@ -0,0 +1,122 @@
+total_iters: 600000
+output_dir: output_dir
+find_unused_parameters: True
+checkpoints_dir: checkpoints
+# tensor range for function tensor2img
+min_max:
+  (0., 1.)
+
+model:
+  name: BasicVSRModel
+  fix_iter: 5000
+  lr_mult: 0.25
+  generator:
+    name: BasicVSRPlusPlus
+    mid_channels: 64
+    num_blocks: 7
+    is_low_res_input: True
+  pixel_criterion:
+    name: CharbonnierLoss
+    reduction: mean
+
+dataset:
+  train:
+    name: RepeatDataset
+    times: 1000
+    num_workers: 4
+    batch_size: 1  # 4 GPUs
+    dataset:
+      name: VSRVimeo90KDataset
+      # mode: train
+      lq_folder: data/vimeo90k/vimeo_septuplet_BD_matlabLRx4/sequences
+      gt_folder: data/vimeo90k/vimeo_septuplet/sequences
+      ann_file: data/vimeo90k/vimeo_septuplet/sep_trainlist.txt
+      preprocess:
+        - name: ReadImageSequence
+          key: lq
+        - name: ReadImageSequence
+          key: gt
+        - name: Transforms
+          input_keys: [lq, gt]
+          pipeline:
+            - name: SRPairedRandomCrop
+              gt_patch_size: 256
+              scale: 4
+              keys: [image, image]
+            - name: PairedRandomHorizontalFlip
+              keys: [image, image]
+            - name: PairedRandomVerticalFlip
+              keys: [image, image]
+            - name: PairedRandomTransposeHW
+              keys: [image, image]
+            - name: TransposeSequence
+              keys: [image, image]
+            - name: MirrorVideoSequence
+            - name: NormalizeSequence
+              mean: [0., 0., 0.]
+              std: [255., 255., 255.]
+              keys: [image, image]
+
+  test:
+    name: VSRFolderDataset
+    # for udm10 dataset
+    # lq_folder: data/udm10/BDx4
+    # gt_folder: data/udm10/GT
+    lq_folder: data/Vid4/BDx4
+    gt_folder: data/Vid4/GT
+    preprocess:
+      - name: GetNeighboringFramesIdx
+        interval_list: [1]
+        # for udm10 dataset
+        # filename_tmpl: '{:04d}.png'
+        filename_tmpl: '{:08d}.png'
+      - name: ReadImageSequence
+        key: lq
+      - name: ReadImageSequence
+        key: gt
+      - name: Transforms
+        input_keys: [lq, gt]
+        pipeline:
+          - name: TransposeSequence
+            keys: [image, image]
+          - name: NormalizeSequence
+            mean: [0., 0., 0.]
+            std: [255., 255., 255.]
+            keys: [image, image]
+
+lr_scheduler:
+  name: CosineAnnealingRestartLR
+  learning_rate: !!float 1e-4
+  periods: [600000]
+  restart_weights: [1]
+  eta_min: !!float 1e-7
+
+optimizer:
+  name: Adam
+  # add parameters of net_names to optim
+  # names should be in self.nets
+  net_names:
+    - generator
+  beta1: 0.9
+  beta2: 0.99
+
+validate:
+  interval: 5000
+  save_img: false
+
+  metrics:
+    psnr: # metric name, can be arbitrary
+      name: PSNR
+      crop_border: 0
+      test_y_channel: true
+    ssim:
+      name: SSIM
+      crop_border: 0
+      test_y_channel: true
+
+log_config:
+  interval: 10
+  visiual_interval: 500
+
+snapshot_config:
+  interval: 5000
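To make the normalization arithmetic in the pipelines above concrete: a minimal NumPy sketch (not PaddleGAN code) of what the `NormalizeSequence` step computes per frame, and why it matches the `min_max: (0., 1.)` setting that `tensor2img` uses to map outputs back to images.

```python
import numpy as np

# NormalizeSequence with mean [0., 0., 0.] and std [255., 255., 255.]
# maps uint8 frames from [0, 255] into the [0, 1] range assumed above.
frames = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
          for _ in range(7)]
clip = np.stack([(f.astype('float32') - 0.) / 255. for f in frames], axis=0)
print(clip.shape, float(clip.min()), float(clip.max()))  # (7, 64, 64, 3), within [0, 1]
```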
diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py
index b419b4b21947c0982219376bae8abecad8585660..e402927b992fb90bd0e0466a6f9a0c1078eec19a 100755
--- a/ppgan/datasets/__init__.py
+++ b/ppgan/datasets/__init__.py
@@ -26,4 +26,6 @@ from .firstorder_dataset import FirstOrderDataset
 from .lapstyle_dataset import LapStyleDataset
 from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset
 from .mpr_dataset import MPRTrain, MPRVal, MPRTest
+from .vsr_vimeo90k_dataset import VSRVimeo90KDataset
+from .vsr_folder_dataset import VSRFolderDataset
 from .photopen_dataset import PhotoPenDataset
diff --git a/ppgan/datasets/preprocess/__init__.py b/ppgan/datasets/preprocess/__init__.py
index 883dce15d2eb48a5e00bd195fd02122809da917a..1712224e761b6f3c8b291ff6695d38fa57eca561 100644
--- a/ppgan/datasets/preprocess/__init__.py
+++ b/ppgan/datasets/preprocess/__init__.py
@@ -1,6 +1,8 @@
-from .io import LoadImageFromFile
+from .io import LoadImageFromFile, ReadImageSequence, GetNeighboringFramesIdx
 from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip,
                          PairedRandomVerticalFlip, PairedRandomTransposeHW,
-                         SRPairedRandomCrop, SplitPairedImage, SRNoise)
+                         SRPairedRandomCrop, SplitPairedImage, SRNoise,
+                         NormalizeSequence, MirrorVideoSequence,
+                         TransposeSequence)
 from .builder import build_preprocess
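For orientation, a small sketch of how the newly exported preprocessor is used; the file paths are hypothetical, and in a real run the whole `preprocess` list from the YAML config is instantiated via `build_preprocess` rather than by hand.

```python
from ppgan.datasets.preprocess import ReadImageSequence

# Hypothetical frame paths; in practice GetNeighboringFramesIdx or the
# dataset fills 'lq_path' with a list of files for one clip.
datas = {'lq_path': ['data/Vid4/BDx4/calendar/00000000.png',
                     'data/Vid4/BDx4/calendar/00000001.png']}
reader = ReadImageSequence(key='lq')   # same options as LoadImageFromFile
datas = reader(datas)                  # datas['lq'] -> list of HWC RGB arrays
```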
""" - filepath = str(results[f'{self.key}_path']) + filepath = str(datas[f'{self.key}_path']) #TODO: use file client to manage io backend # such as opencv, pil, imdb img = cv2.imread(filepath, self.flag) if self.to_rgb: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - results[self.key] = img - results[f'{self.key}_path'] = filepath - results[f'{self.key}_ori_shape'] = img.shape + datas[self.key] = img + datas[f'{self.key}_path'] = filepath + datas[f'{self.key}_ori_shape'] = img.shape if self.save_original_img: - results[f'ori_{self.key}'] = img.copy() + datas[f'ori_{self.key}'] = img.copy() + + return datas + + +@PREPROCESS.register() +class ReadImageSequence(LoadImageFromFile): + """Read image sequence. + + It accepts a list of path and read each frame from each path. A list + of frames will be returned. + + Args: + key (str): Keys in datas to find corresponding path. Default: 'gt'. + flag (str): Loading flag for images. Default: 'color'. + to_rgb (str): Convert img to 'rgb' format. Default: True. + save_original_img (bool): If True, maintain a copy of the image in + `datas` dict with name of `f'ori_{key}'`. Default: False. + kwargs (dict): Args for file client. + """ + def __call__(self, datas): + """Call function. + + Args: + datas (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. + """ + + filepaths = datas[f'{self.key}_path'] + if not isinstance(filepaths, list): + raise TypeError( + f'filepath should be list, but got {type(filepaths)}') + + filepaths = [str(v) for v in filepaths] + + imgs = [] + shapes = [] + if self.save_original_img: + ori_imgs = [] + for filepath in filepaths: + img = cv2.imread(filepath, self.flag) + + if self.to_rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + imgs.append(img) + shapes.append(img.shape) + if self.save_original_img: + ori_imgs.append(img.copy()) + + datas[self.key] = imgs + datas[f'{self.key}_path'] = filepaths + datas[f'{self.key}_ori_shape'] = shapes + if self.save_original_img: + datas[f'ori_{self.key}'] = ori_imgs + + return datas + + +@PREPROCESS.register() +class GetNeighboringFramesIdx: + """Get neighboring frame indices for a video. It also performs temporal + augmention with random interval. + + Args: + interval_list (list[int]): Interval list for temporal augmentation. + It will randomly pick an interval from interval_list and sample + frame index with the interval. + start_idx (int): The index corresponds to the first frame in the + sequence. Default: 0. + filename_tmpl (str): Template for file name. Default: '{:08d}.png'. + """ + def __init__(self, interval_list, start_idx=0, filename_tmpl='{:08d}.png'): + self.interval_list = interval_list + self.filename_tmpl = filename_tmpl + self.start_idx = start_idx + + def __call__(self, datas): + """Call function. + + Args: + datas (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. 
diff --git a/ppgan/datasets/preprocess/transforms.py b/ppgan/datasets/preprocess/transforms.py
index 6ba55361ecc02ce5b0799d2158e654989dca73e3..3064bb3984af4e5e48b251fc55a4134cb0ff7f8a 100644
--- a/ppgan/datasets/preprocess/transforms.py
+++ b/ppgan/datasets/preprocess/transforms.py
@@ -55,6 +55,7 @@ class Transforms():
 
     def __call__(self, datas):
         data = []
+
        for k in self.input_keys:
            data.append(datas[k])
        data = tuple(data)
@@ -133,7 +134,10 @@ class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
 
     def _apply_image(self, image):
         if self.params['flip']:
-            return F.hflip(image)
+            if isinstance(image, list):
+                image = [F.hflip(v) for v in image]
+            else:
+                return F.hflip(image)
         return image
 
@@ -149,7 +153,10 @@ class PairedRandomVerticalFlip(T.RandomHorizontalFlip):
 
     def _apply_image(self, image):
         if self.params['flip']:
-            return F.hflip(image)
+            if isinstance(image, list):
+                image = [F.vflip(v) for v in image]
+            else:
+                return F.vflip(image)
         return image
 
@@ -180,10 +187,108 @@ class PairedRandomTransposeHW(T.BaseTransform):
 
     def _apply_image(self, image):
         if self.params['transpose']:
-            image = image.transpose(1, 0, 2)
+            if isinstance(image, list):
+                image = [v.transpose(1, 0, 2) for v in image]
+            else:
+                image = image.transpose(1, 0, 2)
         return image
 
+
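+# NOTE: PairedRandomHorizontalFlip / PairedRandomVerticalFlip /
+# PairedRandomTransposeHW above now accept either a single HWC array or a
+# list of frames. The flip/transpose decision is sampled once per call, so
+# every frame of a clip (and the paired LQ/GT inputs) receives the same
+# augmentation and the sequence stays temporally consistent.
+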
+@TRANSFORMS.register()
+class TransposeSequence(T.Transpose):
+    """Transpose the input data or a video sequence to a target format.
+    For example, most transforms use HWC mode images, while the neural
+    network might use CHW mode input tensors.
+    The output image will be an instance of numpy.ndarray.
+
+    Args:
+        order (list|tuple, optional): Target order of input data. Default: (2, 0, 1).
+        keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.
+
+    Examples:
+
+    .. code-block:: python
+
+        import numpy as np
+        from PIL import Image
+
+        transform = TransposeSequence()
+
+        fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
+
+        fake_img_seq = [fake_img, fake_img, fake_img]
+        fake_img_seq = transform(fake_img_seq)
+
+    """
+    def _apply_image(self, img):
+        if isinstance(img, list):
+            imgs = []
+            for im in img:
+                if F._is_tensor_image(im):
+                    # keep transposed tensors in the output list instead of
+                    # returning early and dropping the rest of the sequence
+                    imgs.append(im.transpose(self.order))
+                    continue
+
+                if F._is_pil_image(im):
+                    im = np.asarray(im)
+
+                if len(im.shape) == 2:
+                    im = im[..., np.newaxis]
+                imgs.append(im.transpose(self.order))
+            return imgs
+        else:
+            if F._is_tensor_image(img):
+                return img.transpose(self.order)
+
+            if F._is_pil_image(img):
+                img = np.asarray(img)
+
+            if len(img.shape) == 2:
+                img = img[..., np.newaxis]
+            return img.transpose(self.order)
+
+
+@TRANSFORMS.register()
+class NormalizeSequence(T.Normalize):
+    """Normalize the input data with mean and standard deviation.
+    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
+    this transform will normalize each channel of the input data.
+    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+    Args:
+        mean (int|float|list|tuple): Sequence of means for each channel.
+        std (int|float|list|tuple): Sequence of standard deviations for each channel.
+        data_format (str, optional): Data format of img, should be 'HWC' or
+            'CHW'. Default: 'CHW'.
+        to_rgb (bool, optional): Whether to convert to rgb. Default: False.
+        keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.
+
+    Examples:
+
+    .. code-block:: python
+
+        import numpy as np
+        from PIL import Image
+
+        normalize_seq = NormalizeSequence(mean=[127.5, 127.5, 127.5],
+                                          std=[127.5, 127.5, 127.5],
+                                          data_format='HWC')
+
+        fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8))
+        fake_img_seq = [fake_img, fake_img, fake_img]
+        fake_img_seq = normalize_seq(fake_img_seq)
+
+    """
+    def _apply_image(self, img):
+        if isinstance(img, list):
+            imgs = [
+                F.normalize(v, self.mean, self.std, self.data_format,
+                            self.to_rgb) for v in img
+            ]
+            return np.stack(imgs, axis=0).astype('float32')
+
+        return F.normalize(img, self.mean, self.std, self.data_format,
+                           self.to_rgb)
+
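+# NOTE: unlike T.Normalize, NormalizeSequence stacks a list input into one
+# float32 ndarray along a new leading (time) axis, so a 7-frame CHW clip
+# comes out with shape (7, C, H, W); the DataLoader then adds the batch axis.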
+
 @TRANSFORMS.register()
 class SRPairedRandomCrop(T.BaseTransform):
     """Super resolution random crop.
@@ -204,15 +309,19 @@ class SRPairedRandomCrop(T.BaseTransform):
         self.scale_list = scale_list
 
     def __call__(self, inputs):
-        """inputs must be (lq_img, gt_img)"""
+        """inputs must be (lq_img or list[lq_img], gt_img or list[gt_img])"""
         scale = self.scale
         lq_patch_size = self.gt_patch_size // scale
 
         lq = inputs[0]
         gt = inputs[1]
 
-        h_lq, w_lq, _ = lq.shape
-        h_gt, w_gt, _ = gt.shape
+        if isinstance(lq, list):
+            h_lq, w_lq, _ = lq[0].shape
+            h_gt, w_gt, _ = gt[0].shape
+        else:
+            h_lq, w_lq, _ = lq.shape
+            h_gt, w_gt, _ = gt.shape
 
         if h_gt != h_lq * scale or w_gt != w_lq * scale:
             raise ValueError('scale size not match')
@@ -222,18 +331,30 @@
         # randomly choose top and left coordinates for lq patch
         top = random.randint(0, h_lq - lq_patch_size)
         left = random.randint(0, w_lq - lq_patch_size)
-        # crop lq patch
-        lq = lq[top:top + lq_patch_size, left:left + lq_patch_size, ...]
-        # crop corresponding gt patch
-        top_gt, left_gt = int(top * scale), int(left * scale)
-        gt = gt[top_gt:top_gt + self.gt_patch_size,
-                left_gt:left_gt + self.gt_patch_size, ...]
-
-        if self.scale_list and self.scale == 4:
-            lqx2 = F.resize(gt, (lq_patch_size * 2, lq_patch_size * 2),
-                            'bicubic')
-            outputs = (lq, lqx2, gt)
-            return outputs
+
+        if isinstance(lq, list):
+            lq = [
+                v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+                for v in lq
+            ]
+            top_gt, left_gt = int(top * scale), int(left * scale)
+            gt = [
+                v[top_gt:top_gt + self.gt_patch_size,
+                  left_gt:left_gt + self.gt_patch_size, ...] for v in gt
+            ]
+        else:
+            # crop lq patch
+            lq = lq[top:top + lq_patch_size, left:left + lq_patch_size, ...]
+            # crop corresponding gt patch
+            top_gt, left_gt = int(top * scale), int(left * scale)
+            gt = gt[top_gt:top_gt + self.gt_patch_size,
+                    left_gt:left_gt + self.gt_patch_size, ...]
+
+            if self.scale_list and self.scale == 4:
+                lqx2 = F.resize(gt, (lq_patch_size * 2, lq_patch_size * 2),
+                                'bicubic')
+                outputs = (lq, lqx2, gt)
+                return outputs
 
         outputs = (lq, gt)
         return outputs
@@ -411,3 +532,36 @@ class PairedColorJitter(T.BaseTransform):
         for f in self.params:
             img = f(img)
         return img
+
+
+@TRANSFORMS.register()
+class MirrorVideoSequence:
+    """Double a short video sequence by mirroring it along the temporal axis.
+
+    Example:
+        Given a sequence with N frames (x1, ..., xN), extend the
+        sequence to (x1, ..., xN, xN, ..., x1).
+
+    Args:
+        keys (list[str]): The frame lists to be extended.
+    """
+    def __init__(self, keys=None):
+        self.keys = keys
+
+    def __call__(self, datas):
+        """Call function.
+
+        Args:
+            datas (tuple): A tuple (lrs, hrs) of LQ and GT frame lists.
+
+        Returns:
+            tuple: The mirrored (lrs, hrs) frame lists.
+        """
+        lrs, hrs = datas
+        assert isinstance(lrs, list) and isinstance(hrs, list)
+
+        lrs = lrs + lrs[::-1]
+        hrs = hrs + hrs[::-1]
+
+        return (lrs, hrs)
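To make the paired-crop geometry concrete, a standalone sketch (illustrative shapes, not library code) of how one sampled `(top, left)` is shared by every frame of a clip and scaled for the GT crop:

```python
import random
import numpy as np

scale, gt_patch_size = 4, 256
lq_patch_size = gt_patch_size // scale  # 64
lq = [np.zeros((180, 320, 3), np.uint8) for _ in range(7)]
gt = [np.zeros((720, 1280, 3), np.uint8) for _ in range(7)]
# sample the crop origin once, reuse it for all frames
top = random.randint(0, 180 - lq_patch_size)
left = random.randint(0, 320 - lq_patch_size)
lq = [v[top:top + lq_patch_size, left:left + lq_patch_size] for v in lq]
top_gt, left_gt = top * scale, left * scale
gt = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size]
      for v in gt]
print(lq[0].shape, gt[0].shape)  # (64, 64, 3) (256, 256, 3)
```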
diff --git a/ppgan/datasets/vsr_folder_dataset.py b/ppgan/datasets/vsr_folder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0001fbe37a553642da9552d1ae9a45418af62ad6
--- /dev/null
+++ b/ppgan/datasets/vsr_folder_dataset.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import glob
+import random
+import logging
+import numpy as np
+from paddle.io import Dataset
+
+from .base_sr_dataset import BaseDataset
+from .builder import DATASETS
+
+logger = logging.getLogger(__name__)
+
+
+@DATASETS.register()
+class VSRFolderDataset(BaseDataset):
+    """Video super-resolution dataset for folder-format data.
+
+    Args:
+        lq_folder (str): Path to a low quality image folder.
+        gt_folder (str): Path to a ground truth image folder.
+        preprocess (list[dict|callable]): A list of functions for data
+            transformations.
+        num_frames (int): Number of frames of each input clip. When None,
+            all frames of a sequence are used. Default: None.
+        times (int): Repeat times of dataset length. Default: 1.
+    """
+    def __init__(self,
+                 lq_folder,
+                 gt_folder,
+                 preprocess,
+                 num_frames=None,
+                 times=1):
+        super().__init__(preprocess)
+
+        self.lq_folder = str(lq_folder)
+        self.gt_folder = str(gt_folder)
+        self.num_frames = num_frames
+        self.times = times
+
+        self.data_infos = self.prepare_data_infos()
+
+    def prepare_data_infos(self):
+
+        sequences = sorted(glob.glob(os.path.join(self.lq_folder, '*')))
+
+        data_infos = []
+        for sequence in sequences:
+            sequence_length = len(glob.glob(os.path.join(sequence, '*.png')))
+            if self.num_frames is None:
+                num_frames = sequence_length
+            else:
+                num_frames = self.num_frames
+            data_infos.append(
+                dict(lq_path=self.lq_folder,
+                     gt_path=self.gt_folder,
+                     key=sequence.replace(f'{self.lq_folder}/', ''),
+                     num_frames=num_frames,
+                     sequence_length=sequence_length))
+        return data_infos
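For reference, the folder layout `VSRFolderDataset` expects with the Vid4 test config above, and the `data_infos` record it produces; paths and frame counts are illustrative:

```python
# Illustrative on-disk layout (one sub-folder per clip, PNG frames named
# by GetNeighboringFramesIdx's filename_tmpl '{:08d}.png'):
#
#   data/Vid4/BDx4/calendar/00000000.png ...   (LQ frames)
#   data/Vid4/GT/calendar/00000000.png   ...   (GT frames, same structure)
#
# prepare_data_infos() then yields one record per clip folder, e.g.:
info = dict(
    lq_path='data/Vid4/BDx4',    # roots, not per-frame paths;
    gt_path='data/Vid4/GT',      # GetNeighboringFramesIdx expands them
    key='calendar',              # clip sub-folder name
    num_frames=41,               # all frames when num_frames is None
    sequence_length=41)
```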
diff --git a/ppgan/datasets/vsr_vimeo90k_dataset.py b/ppgan/datasets/vsr_vimeo90k_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee141f34e7163abe30277189f2f4940b176b21e3
--- /dev/null
+++ b/ppgan/datasets/vsr_vimeo90k_dataset.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import glob
+import random
+import logging
+import numpy as np
+from paddle.io import Dataset
+
+from .base_sr_dataset import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register()
+class VSRVimeo90KDataset(BaseDataset):
+    """Vimeo90K dataset for video super-resolution with recurrent networks.
+
+    The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
+    frames. Then it applies specified transforms and finally returns a dict
+    containing paired data and other information.
+
+    It reads Vimeo90K keys from the annotation txt file, where each line
+    records one video frame folder.
+
+    Examples:
+
+        00001/0233
+        00001/0234
+
+    Args:
+        lq_folder (str): Path to a low quality image folder.
+        gt_folder (str): Path to a ground truth image folder.
+        ann_file (str): Path to the annotation file.
+        preprocess (list[dict|callable]): A list of functions for data
+            transformations.
+    """
+    def __init__(self, lq_folder, gt_folder, ann_file, preprocess):
+        super().__init__(preprocess)
+
+        self.lq_folder = str(lq_folder)
+        self.gt_folder = str(gt_folder)
+        self.ann_file = str(ann_file)
+
+        self.data_infos = self.prepare_data_infos()
+
+    def prepare_data_infos(self):
+
+        with open(self.ann_file, 'r') as fin:
+            keys = [line.strip() for line in fin]
+
+        data_infos = []
+        for key in keys:
+            lq_paths = sorted(
+                glob.glob(os.path.join(self.lq_folder, key, '*.png')))
+            gt_paths = sorted(
+                glob.glob(os.path.join(self.gt_folder, key, '*.png')))
+
+            data_infos.append(dict(lq_path=lq_paths, gt_path=gt_paths,
+                                   key=key))
+
+        return data_infos
diff --git a/ppgan/metrics/psnr_ssim.py b/ppgan/metrics/psnr_ssim.py
index 72702de0423b322dacfa676f4b66f534580c781b..e6885fb044584b9ae4dc39b2f07ec5a674bdb397 100644
--- a/ppgan/metrics/psnr_ssim.py
+++ b/ppgan/metrics/psnr_ssim.py
@@ -30,17 +30,26 @@ class PSNR(paddle.metric.Metric):
     def reset(self):
         self.results = []
 
-    def update(self, preds, gts):
+    def update(self, preds, gts, is_seq=False):
         if not isinstance(preds, (list, tuple)):
             preds = [preds]
 
         if not isinstance(gts, (list, tuple)):
             gts = [gts]
 
+        if is_seq:
+            single_seq = []
+
         for pred, gt in zip(preds, gts):
             value = calculate_psnr(pred, gt, self.crop_border,
                                    self.input_order, self.test_y_channel)
-            self.results.append(value)
+            if is_seq:
+                single_seq.append(value)
+            else:
+                self.results.append(value)
+
+        if is_seq:
+            self.results.append(np.mean(single_seq))
 
     def accumulate(self):
         if paddle.distributed.get_world_size() > 1:
@@ -59,17 +68,26 @@ class PSNR(paddle.metric.Metric):
 
 @METRICS.register()
 class SSIM(PSNR):
-    def update(self, preds, gts):
+    def update(self, preds, gts, is_seq=False):
         if not isinstance(preds, (list, tuple)):
             preds = [preds]
 
         if not isinstance(gts, (list, tuple)):
             gts = [gts]
 
+        if is_seq:
+            single_seq = []
+
         for pred, gt in zip(preds, gts):
             value = calculate_ssim(pred, gt, self.crop_border,
                                    self.input_order, self.test_y_channel)
-            self.results.append(value)
+            if is_seq:
+                single_seq.append(value)
+            else:
+                self.results.append(value)
+
+        if is_seq:
+            self.results.append(np.mean(single_seq))
 
     def name(self):
         return 'SSIM'
diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py
index ede254ec35e437f9e680988c14b06a6217d387d0..ae4ecd2bdc7ddd077640c5752fd7dabf14436ba7 100755
--- a/ppgan/models/base_model.py
+++ b/ppgan/models/base_model.py
@@ -25,7 +25,7 @@ from ..utils.visual import tensor2img
 
 
 class BaseModel(ABC):
-    """This class is an abstract base class (ABC) for models.
+    r"""This class is an abstract base class (ABC) for models.
     To create a subclass, you need to implement the following five functions:
         -- <__init__>: initialize the class.
         -- <setup_input>: unpack data from dataset and apply preprocessing.
diff --git a/ppgan/models/basicvsr_model.py b/ppgan/models/basicvsr_model.py
index b2db12879ae60cc076692cf6eb7466f61d21aca8..54a9b5454842339259143991514cb392511c0014 100644
--- a/ppgan/models/basicvsr_model.py
+++ b/ppgan/models/basicvsr_model.py
@@ -103,7 +103,7 @@ class BasicVSRModel(BaseSRModel):
 
         if metrics is not None:
             for metric in metrics.values():
-                metric.update(out_img, gt_img)
+                metric.update(out_img, gt_img, is_seq=True)
 
 
 def init_basicvsr_weight(net):
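A small sketch of what `is_seq=True` changes in the metric bookkeeping; the PSNR values are made up:

```python
import numpy as np

# With is_seq=True, per-frame PSNR values from one clip are collapsed into
# a single entry, so accumulate() averages over clips rather than frames.
clip_a = [27.1, 27.8, 28.0]   # frame PSNRs of one clip
clip_b = [31.0, 30.6]
results = [np.mean(clip_a), np.mean(clip_b)]  # what update(..., is_seq=True) stores
print(np.mean(results))  # final metric: mean over clips, not over all 5 frames
```

This per-sequence averaging is why `BasicVSRModel` passes `is_seq=True`: a long clip no longer outweighs a short one in the reported PSNR/SSIM.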