diff --git a/applications/EDVR/configs/edvr_L.yaml b/applications/EDVR/configs/edvr_L.yaml index 41b6f264356b114aef0887f6e7d72fb1e35340c2..75ab1371b6cb94f62282183fca2ff7f8be91a230 100644 --- a/applications/EDVR/configs/edvr_L.yaml +++ b/applications/EDVR/configs/edvr_L.yaml @@ -11,18 +11,6 @@ MODEL: HR_in: False w_TSA: True #False -TEST: - scale: 4 - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 1 - file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic" - gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT" - use_flip: False - use_rot: False - INFER: scale: 4 crop_size: 256 @@ -31,6 +19,6 @@ INFER: number_frames: 5 batch_size: 1 file_root: "/workspace/color/input_frames" - gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT" + #gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT" use_flip: False use_rot: False diff --git a/applications/EDVR/inference_model.py b/applications/EDVR/inference_model.py index 9ed5d272ae5aa025c2e3d2c78fc15e890c064f04..697253a87ec4731c3b57540da20291d7b64b1024 100644 --- a/applications/EDVR/inference_model.py +++ b/applications/EDVR/inference_model.py @@ -28,7 +28,7 @@ import paddle.fluid as fluid from utils.config_utils import * import models from reader import get_reader -from metrics import get_metrics +#from metrics import get_metrics from utils.utility import check_cuda logging.root.handlers = [] diff --git a/applications/EDVR/metrics/__init__.py b/applications/EDVR/metrics/__init__.py deleted file mode 100644 index 0d1df762bdf3d3b920fc1e00d15a3a2ecdcdbe55..0000000000000000000000000000000000000000 --- a/applications/EDVR/metrics/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .metrics_util import get_metrics diff --git a/applications/EDVR/metrics/edvr_metrics/__init__.py b/applications/EDVR/metrics/edvr_metrics/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/applications/EDVR/metrics/edvr_metrics/edvr_metrics.py b/applications/EDVR/metrics/edvr_metrics/edvr_metrics.py deleted file mode 100644 index 6f1a6fff59bfb6a61e29ebb1465853ee381604d4..0000000000000000000000000000000000000000 --- a/applications/EDVR/metrics/edvr_metrics/edvr_metrics.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and - -import numpy as np -import datetime -import logging -import json -import os -import cv2 -import math - -logger = logging.getLogger(__name__) - -class MetricsCalculator(): - def __init__( - self, - name='EDVR', - mode='train'): - self.name = name - self.mode = mode # 'train', 'valid', 'test', 'infer' - self.reset() - self.total_frames = 9002 #100 - self.bolder_frames = 2 - - def reset(self): - logger.info('Resetting {} metrics...'.format(self.mode)) - if (self.mode == 'train') or (self.mode == 'valid'): - self.aggr_loss = 0.0 - elif (self.mode == 'test') or (self.mode == 'infer'): - self.result_dict = dict() - - def calculate_and_logout(self, fetch_list, info): - pass - - def accumulate(self, fetch_list): - loss = fetch_list[0] - pred = fetch_list[1] - gt = fetch_list[2] - videoinfo = fetch_list[-1] - print('videoinfo: ', videoinfo) - videonames = [item[0] for item in videoinfo] - framenames = [item[1] for item in videoinfo] - for i in range(len(pred)): - pred_i = pred[i] - gt_i = gt[i] - videoname_i = videonames[i] - framename_i = framenames[i] - if videoname_i not in self.result_dict.keys(): - self.result_dict[videoname_i] = {} - if framename_i in self.result_dict[videoname_i].keys(): - logger.info("frame {} already processed in video {}, please check it".format(framename_i, videoname_i)) - raise - is_bolder = (int(framename_i) > (self.total_frames - self.bolder_frames - 1) - or int(framename_i) < self.bolder_frames) - psnr_i = get_psnr(pred_i, gt_i) - img_i = get_img(pred_i) - self.result_dict[videoname_i][framename_i] = [is_bolder, psnr_i] - is_save = True - if is_save and (i == len(pred) - 1): - save_img(img_i, framename_i) - logger.info("video {}, frame {}, bolder {}, psnr = {}".format(videoname_i, framename_i, is_bolder, psnr_i)) - - - def finalize_metrics(self, savedir): - avg_psnr = 0. - avg_psnr_center = 0. - avg_psnr_bolder = 0. - center_num = 0. - bolder_num = 0. - for videoname in self.result_dict.keys(): - videoresult = self.result_dict[videoname] - framelist = list(videoresult.keys()) - video_psnr_center = 0. - video_psnr_bolder = 0. - video_center_num = 0. - video_bolder_num = 0. - for frame in framelist: - frameresult = videoresult[frame] - is_bolder = frameresult[0] - psnr = frameresult[1] - if is_bolder: - video_bolder_num += 1 - video_psnr_bolder += psnr - else: - video_center_num += 1 - video_psnr_center += psnr - video_num = video_bolder_num + video_center_num - video_psnr = video_psnr_center + video_psnr_bolder - avg_psnr_bolder += video_psnr_bolder - avg_psnr_center += video_psnr_center - bolder_num += video_bolder_num - center_num += video_center_num - logger.info("video {}, total frame num/psnr {}/{}, center num/psnr {}/{}, bolder num/psnr {}/{}".format( - videoname, video_num, video_psnr/video_num, - video_center_num, video_psnr_center/video_center_num, - video_bolder_num, video_psnr_bolder/video_bolder_num)) - avg_psnr = avg_psnr_bolder + avg_psnr_center - total_num = bolder_num + center_num - avg_psnr = avg_psnr / total_num - avg_psnr_center = avg_psnr_center / center_num - avg_psnr_bolder = avg_psnr_bolder / bolder_num - logger.info("Average psnr {}, center {}, bolder {}".format(avg_psnr, avg_psnr_center, avg_psnr_bolder)) - - -def get_psnr(pred, gt): - # pred and gt have range [0, 1] - pred = pred.squeeze().astype(np.float64) - pred = pred * 255. - pred = pred.round() - gt = gt.squeeze().astype(np.float64) - gt = gt * 255. - gt = gt.round() - mse = np.mean((pred - gt)**2) - if mse == 0: - return float('inf') - return 20 * math.log10(255.0 / math.sqrt(mse)) - - -def get_img(pred): - print('pred shape', pred.shape) - pred = pred.squeeze() - pred = np.clip(pred, a_min=0., a_max=1.0) - pred = pred * 255 - pred = pred.round() - pred = pred.astype('uint8') - pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc - pred = pred[:, :, ::-1] # rgb -> bgr - return pred - -def save_img(img, framename): - dirname = './demo/resultpng' - filename = os.path.join(dirname, framename+'.png') - cv2.imwrite(filename, img) - - diff --git a/applications/EDVR/metrics/metrics_util.py b/applications/EDVR/metrics/metrics_util.py deleted file mode 100644 index 611beca7905398f71a05098a813b01ea813254ec..0000000000000000000000000000000000000000 --- a/applications/EDVR/metrics/metrics_util.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. - -from __future__ import absolute_import -from __future__ import unicode_literals -from __future__ import print_function -from __future__ import division - -import logging - -import numpy as np -import json -from metrics.edvr_metrics import edvr_metrics as edvr_metrics - -logger = logging.getLogger(__name__) - - -class Metrics(object): - def __init__(self, name, mode, metrics_args): - """Not implemented""" - pass - - def calculate_and_log_out(self, fetch_list, info=''): - """Not implemented""" - pass - - def accumulate(self, fetch_list, info=''): - """Not implemented""" - pass - - def finalize_and_log_out(self, info='', savedir='./'): - """Not implemented""" - pass - - def reset(self): - """Not implemented""" - pass - - -class EDVRMetrics(Metrics): - def __init__(self, name, mode, cfg): - self.name = name - self.mode = mode - args = {} - args['mode'] = mode - args['name'] = name - self.calculator = edvr_metrics.MetricsCalculator(**args) - - def calculate_and_log_out(self, fetch_list, info=''): - if (self.mode == 'train') or (self.mode == 'valid'): - loss = np.array(fetch_list[0]) - logger.info(info + '\tLoss = {}'.format('%.04f' % np.mean(loss))) - elif self.mode == 'test': - pass - - def accumulate(self, fetch_list): - self.calculator.accumulate(fetch_list) - - def finalize_and_log_out(self, info='', savedir='./'): - self.calculator.finalize_metrics(savedir) - - def reset(self): - self.calculator.reset() - - -class MetricsZoo(object): - def __init__(self): - self.metrics_zoo = {} - - def regist(self, name, metrics): - assert metrics.__base__ == Metrics, "Unknow model type {}".format( - type(metrics)) - self.metrics_zoo[name] = metrics - - def get(self, name, mode, cfg): - for k, v in self.metrics_zoo.items(): - if k == name: - return v(name, mode, cfg) - raise MetricsNotFoundError(name, self.metrics_zoo.keys()) - - -# singleton metrics_zoo -metrics_zoo = MetricsZoo() - - -def regist_metrics(name, metrics): - metrics_zoo.regist(name, metrics) - - -def get_metrics(name, mode, cfg): - return metrics_zoo.get(name, mode, cfg) - - -# sort by alphabet -regist_metrics("EDVR", EDVRMetrics) diff --git a/applications/EDVR/predict.py b/applications/EDVR/predict.py index 154d4c48d339f3fab2d1f68e82fddc113da749f7..d3e97f1473054c82fb9bb12b9363b4b125fd09f9 100644 --- a/applications/EDVR/predict.py +++ b/applications/EDVR/predict.py @@ -29,7 +29,7 @@ import cv2 from utils.config_utils import * import models from reader import get_reader -from metrics import get_metrics +#from metrics import get_metrics from utils.utility import check_cuda from utils.utility import check_version @@ -56,12 +56,6 @@ def parse_args(): type=ast.literal_eval, default=True, help='default use gpu.') - # parser.add_argument( - # '--weights', - # type=str, - # default=None, - # help='weight path, None to automatically download weights provided by Paddle.' - # ) parser.add_argument( '--batch_size', type=int,