From 8c46c4439b742f5dba8e182d50b6407c580954c3 Mon Sep 17 00:00:00 2001 From: lijianshe02 Date: Wed, 12 Aug 2020 06:22:48 +0000 Subject: [PATCH] delete inference irralated code --- applications/EDVR/configs/edvr_L.yaml | 37 -- applications/EDVR/models/__init__.py | 27 -- applications/EDVR/models/edvr/README.md | 112 ------ applications/EDVR/reader/edvr_reader.py.bak | 395 ------------------- applications/EDVR/reader/edvr_reader.pybk | 398 -------------------- applications/EDVR/run.sh | 57 +-- 6 files changed, 4 insertions(+), 1022 deletions(-) delete mode 100644 applications/EDVR/models/edvr/README.md delete mode 100644 applications/EDVR/reader/edvr_reader.py.bak delete mode 100644 applications/EDVR/reader/edvr_reader.pybk diff --git a/applications/EDVR/configs/edvr_L.yaml b/applications/EDVR/configs/edvr_L.yaml index 2733e74..41b6f26 100644 --- a/applications/EDVR/configs/edvr_L.yaml +++ b/applications/EDVR/configs/edvr_L.yaml @@ -11,43 +11,6 @@ MODEL: HR_in: False w_TSA: True #False -TRAIN: - epoch: 45 - use_gpu: True - num_gpus: 4 #8 - scale: 4 - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 32 - file_root: "/workspace/video_test/video/data/dataset/edvr/REDS/train_sharp_bicubic/X4" - gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS/train_sharp" - use_flip: True - use_rot: True - base_learning_rate: 0.0004 - l2_weight_decay: 0.0 - TSA_only: False - T_periods: [50000, 100000, 150000, 150000, 150000] # for cosine annealing restart - restarts: [50000, 150000, 300000, 450000] # for cosine annealing restart - weights: [1, 1, 1, 1] # for cosine annealing restart - eta_min: 1e-7 # for cosine annealing restart - num_reader_threads: 8 - buf_size: 1024 - fix_random_seed: False - -VALID: - scale: 4 - crop_size: 256 - interval_list: [1] - random_reverse: False - number_frames: 5 - batch_size: 32 #256 - file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic" - gt_root: "/workspace/video_test/video/data/dataset/edvr/edvr/REDS4/GT" - use_flip: False - use_rot: False - TEST: scale: 4 crop_size: 256 diff --git a/applications/EDVR/models/__init__.py b/applications/EDVR/models/__init__.py index 67b12a9..60d90ea 100644 --- a/applications/EDVR/models/__init__.py +++ b/applications/EDVR/models/__init__.py @@ -1,31 +1,4 @@ from .model import regist_model, get_model -#from .attention_cluster import AttentionCluster -#from .attention_lstm import AttentionLSTM -#from .nextvlad import NEXTVLAD -#from .nonlocal_model import NonLocal -#from .tsm import TSM -#from .tsn import TSN -#from .stnet import STNET -#from .ctcn import CTCN -#from .bmn import BMN -#from .bsn import BsnTem -#from .bsn import BsnPem -#from .ets import ETS -#from .tall import TALL from .edvr import EDVR -# regist models, sort by alphabet -#regist_model("AttentionCluster", AttentionCluster) -#regist_model("AttentionLSTM", AttentionLSTM) -#regist_model("NEXTVLAD", NEXTVLAD) -#regist_model('NONLOCAL', NonLocal) -#regist_model("TSM", TSM) -#regist_model("TSN", TSN) -#regist_model("STNET", STNET) -#regist_model("CTCN", CTCN) -#regist_model("BMN", BMN) -#regist_model("BsnTem", BsnTem) -#regist_model("BsnPem", BsnPem) -#regist_model("ETS", ETS) -#regist_model("TALL", TALL) regist_model("EDVR", EDVR) diff --git a/applications/EDVR/models/edvr/README.md b/applications/EDVR/models/edvr/README.md deleted file mode 100644 index 056b1e5..0000000 --- a/applications/EDVR/models/edvr/README.md +++ /dev/null @@ -1,112 +0,0 @@ -# TSN 视频分类模型 - ---- -## 内容 - -- [模型简介](#模型简介) -- [数据准备](#数据准备) -- [模型训练](#模型训练) -- [模型评估](#模型评估) -- [模型推断](#模型推断) -- [参考论文](#参考论文) - - -## 模型简介 - -Temporal Segment Network (TSN) 是视频分类领域经典的基于2D-CNN的解决方案。该方法主要解决视频的长时间行为判断问题,通过稀疏采样视频帧的方式代替稠密采样,既能捕获视频全局信息,也能去除冗余,降低计算量。最终将每帧特征平均融合后得到视频的整体特征,并用于分类。本代码实现的模型为基于单路RGB图像的TSN网络结构,Backbone采用ResNet-50结构。 - -详细内容请参考ECCV 2016年论文[Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859) - -## 数据准备 - -TSN的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。数据下载及准备请参考[数据说明](../../data/dataset/README.md) - -## 模型训练 - -数据准备完毕后,可以通过如下两种方式启动训练: - - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - export FLAGS_fast_eager_deletion_mode=1 - export FLAGS_eager_delete_tensor_gb=0.0 - export FLAGS_fraction_of_gpu_memory_to_use=0.98 - python train.py --model_name=TSN \ - --config=./configs/tsn.yaml \ - --log_interval=10 \ - --valid_interval=1 \ - --use_gpu=True \ - --save_dir=./data/checkpoints \ - --fix_random_seed=False \ - --pretrain=$PATH_TO_PRETRAIN_MODEL - - bash run.sh train TSN ./configs/tsn.yaml - -- 从头开始训练,需要加载在ImageNet上训练的ResNet50权重作为初始化参数,请下载此[模型参数](https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz)并解压,将上面启动命令行或者run.sh脚本中的`pretrain`参数设置为解压之后的模型参数 -存放路径。如果没有手动下载并设置`pretrain`参数,则程序会自动下载并将参数保存在~/.paddle/weights/ResNet50\_pretrained目录下面 - -- 可下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams)通过`--resume`指定权重存 -放路径进行finetune等开发 - -**数据读取器说明:** 模型读取Kinetics-400数据集中的`mp4`数据,每条数据抽取`seg_num`段,每段抽取1帧图像,对每帧图像做随机增强后,缩放至`target_size`。 - -**训练策略:** - -* 采用Momentum优化算法训练,momentum=0.9 -* 权重衰减系数为1e-4 -* 学习率在训练的总epoch数的1/3和2/3时分别做0.1的衰减 - -## 模型评估 - -可通过如下两种方式进行模型评估: - - python eval.py --model_name=TSN \ - --config=./configs/tsn.yaml \ - --log_interval=1 \ - --weights=$PATH_TO_WEIGHTS \ - --use_gpu=True - - bash run.sh eval TSN ./configs/tsn.yaml - -- 使用`run.sh`进行评估时,需要修改脚本中的`weights`参数指定需要评估的权重 - -- 若未指定`--weights`参数,脚本会下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams)进行评估 - -- 评估结果以log的形式直接打印输出TOP1\_ACC、TOP5\_ACC等精度指标 - -- 使用CPU进行评估时,请将上面的命令行或者run.sh脚本中的`use_gpu`设置为False - - -当取如下参数时,在Kinetics400的validation数据集下评估精度如下: - -| seg\_num | target\_size | Top-1 | -| :------: | :----------: | :----: | -| 3 | 224 | 0.66 | -| 7 | 224 | 0.67 | - -## 模型推断 - -可通过如下两种方式启动模型推断: - - python predict.py --model_name=TSN \ - --config=./configs/tsn.yaml \ - --log_interval=1 \ - --weights=$PATH_TO_WEIGHTS \ - --filelist=$FILELIST \ - --use_gpu=True \ - --video_path=$VIDEO_PATH - - bash run.sh predict TSN ./configs/tsn.yaml - -- 使用`run.sh`进行评估时,需要修改脚本中的`weights`参数指定需要用到的权重。 - -- 如果video\_path为'', 则忽略掉此参数。如果video\_path != '',则程序会对video\_path指定的视频文件进行预测,而忽略掉filelist的值,预测结果为此视频的分类概率。 - -- 若未指定`--weights`参数,脚本会下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams)进行推断 - -- 模型推断结果以log的形式直接打印输出,可以看到测试样本的分类预测概率。 - -- 使用CPU进行推断时,请将命令行或者run.sh脚本中的`use_gpu`设置为False - -## 参考论文 - -- [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859), Limin Wang, Yuanjun Xiong, Zhe Wang, Yu Qiao, Dahua Lin, Xiaoou Tang, Luc Van Gool - diff --git a/applications/EDVR/reader/edvr_reader.py.bak b/applications/EDVR/reader/edvr_reader.py.bak deleted file mode 100644 index 152ab19..0000000 --- a/applications/EDVR/reader/edvr_reader.py.bak +++ /dev/null @@ -1,395 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. - -import os -import sys -import cv2 -import math -import random -import multiprocessing -import functools -import numpy as np -import paddle -import cv2 -import logging -from .reader_utils import DataReader - -logger = logging.getLogger(__name__) -python_ver = sys.version_info - -random.seed(0) -np.random.seed(0) - -class EDVRReader(DataReader): - """ - Data reader for video super resolution task fit for EDVR model. - This is specified for REDS dataset. - """ - def __init__(self, name, mode, cfg): - super(EDVRReader, self).__init__(name, mode, cfg) - self.format = cfg.MODEL.format - self.crop_size = self.get_config_from_sec(mode, 'crop_size') - self.interval_list = self.get_config_from_sec(mode, 'interval_list') - self.random_reverse = self.get_config_from_sec(mode, 'random_reverse') - self.number_frames = self.get_config_from_sec(mode, 'number_frames') - # set batch size and file list - self.batch_size = cfg[mode.upper()]['batch_size'] - self.fileroot = cfg[mode.upper()]['file_root'] - self.use_flip = self.get_config_from_sec(mode, 'use_flip', False) - self.use_rot = self.get_config_from_sec(mode, 'use_rot', False) - - self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1) - self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024) - self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False) - - if self.mode != 'infer': - self.gtroot = self.get_config_from_sec(mode, 'gt_root') - self.scale = self.get_config_from_sec(mode, 'scale', 1) - self.LR_input = (self.scale > 1) - if self.fix_random_seed: - random.seed(0) - np.random.seed(0) - self.num_reader_threads = 1 - - def create_reader(self): - logger.info('initialize reader ... ') - self.filelist = [] - for video_name in os.listdir(self.fileroot): - if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']): - continue - for frame_name in os.listdir(os.path.join(self.fileroot, video_name)): - frame_idx = frame_name.split('.')[0] - video_frame_idx = video_name + '_' + frame_idx - # for each item in self.filelist is like '010_00000015', '260_00000090' - self.filelist.append(video_frame_idx) - if self.mode == 'test': - self.filelist.sort() - - if self.num_reader_threads == 1: - reader_func = make_reader - else: - reader_func = make_multi_reader - - - return reader_func(filelist = self.filelist, - num_threads = self.num_reader_threads, - batch_size = self.batch_size, - is_training = (self.mode == 'train'), - number_frames = self.number_frames, - interval_list = self.interval_list, - random_reverse = self.random_reverse, - gtroot = self.gtroot, - fileroot = self.fileroot, - LR_input = self.LR_input, - crop_size = self.crop_size, - scale = self.scale, - use_flip = self.use_flip, - use_rot = self.use_rot, - mode = self.mode) - - -def get_sample_data(item, number_frames, interval_list, random_reverse, gtroot, fileroot, - LR_input, crop_size, scale, use_flip, use_rot, mode='train'): - video_name = item.split('_')[0] - frame_name = item.split('_')[1] - if (mode == 'train') or (mode == 'valid'): - ngb_frames, name_b = get_neighbor_frames(frame_name, \ - number_frames = number_frames, \ - interval_list = interval_list, \ - random_reverse = random_reverse) - elif mode == 'test': - ngb_frames, name_b = get_test_neighbor_frames(int(frame_name), number_frames) - else: - raise NotImplementedError('mode {} not implemented'.format(mode)) - frame_name = name_b - print('key2', ngb_frames, name_b) - img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'), is_gt=True) - #print('gt_mean', np.mean(img_GT)) - frame_list = [] - for ngb_frm in ngb_frames: - ngb_name = "%08d"%ngb_frm - #img = read_img(os.path.join(fileroot, video_name, frame_name + '.png')) - img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png')) - frame_list.append(img) - #print('img_mean', np.mean(img)) - - H, W, C = frame_list[0].shape - # add random crop - if (mode == 'train') or (mode == 'valid'): - if LR_input: - LQ_size = crop_size // scale - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - #print('rnd_h {}, rnd_w {}', rnd_h, rnd_w) - frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list] - rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) - img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :] - else: - rnd_h = random.randint(0, max(0, H - crop_size)) - rnd_w = random.randint(0, max(0, W - crop_size)) - frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list] - img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] - - # add random flip and rotation - frame_list.append(img_GT) - if (mode == 'train') or (mode == 'valid'): - rlt = img_augment(frame_list, use_flip, use_rot) - else: - rlt = frame_list - frame_list = rlt[0:-1] - img_GT = rlt[-1] - - # stack LQ images to NHWC, N is the frame number - img_LQs = np.stack(frame_list, axis=0) - # BGR to RGB, HWC to CHW, numpy to tensor - img_GT = img_GT[:, :, [2, 1, 0]] - img_LQs = img_LQs[:, :, :, [2, 1, 0]] - img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32') - img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32') - - return img_LQs, img_GT - -def get_test_neighbor_frames(crt_i, N, max_n=100, padding='new_info'): - """Generate an index list for reading N frames from a sequence of images - Args: - crt_i (int): current center index - max_n (int): max number of the sequence of images (calculated from 1) - N (int): reading N frames - padding (str): padding mode, one of replicate | reflection | new_info | circle - Example: crt_i = 0, N = 5 - replicate: [0, 0, 0, 1, 2] - reflection: [2, 1, 0, 1, 2] - new_info: [4, 3, 0, 1, 2] - circle: [3, 4, 0, 1, 2] - - Returns: - return_l (list [int]): a list of indexes - """ - max_n = max_n - 1 - n_pad = N // 2 - return_l = [] - - for i in range(crt_i - n_pad, crt_i + n_pad + 1): - if i < 0: - if padding == 'replicate': - add_idx = 0 - elif padding == 'reflection': - add_idx = -i - elif padding == 'new_info': - add_idx = (crt_i + n_pad) + (-i) - elif padding == 'circle': - add_idx = N + i - else: - raise ValueError('Wrong padding mode') - elif i > max_n: - if padding == 'replicate': - add_idx = max_n - elif padding == 'reflection': - add_idx = max_n * 2 - i - elif padding == 'new_info': - add_idx = (crt_i - n_pad) - (i - max_n) - elif padding == 'circle': - add_idx = i - N - else: - raise ValueError('Wrong padding mode') - else: - add_idx = i - return_l.append(add_idx) - name_b = '{:08d}'.format(crt_i) - return return_l, name_b - - -def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, max_frame=99, bordermode=False): - center_frame_idx = int(frame_name) - half_N_frames = number_frames // 2 - #### determine the neighbor frames - interval = random.choice(interval_list) - if bordermode: - direction = 1 # 1: forward; 0: backward - if random_reverse and random.random() < 0.5: - direction = random.choice([0, 1]) - if center_frame_idx + interval * (number_frames - 1) > max_frame: - direction = 0 - elif center_frame_idx - interval * (number_frames - 1) < 0: - direction = 1 - # get the neighbor list - if direction == 1: - neighbor_list = list( - range(center_frame_idx, center_frame_idx + interval * number_frames, interval)) - else: - neighbor_list = list( - range(center_frame_idx, center_frame_idx - interval * number_frames, -interval)) - name_b = '{:08d}'.format(neighbor_list[0]) - else: - # ensure not exceeding the borders - while (center_frame_idx + half_N_frames * interval > - max_frame) or (center_frame_idx - half_N_frames * interval < 0): - center_frame_idx = random.randint(0, max_frame) - # get the neighbor list - neighbor_list = list( - range(center_frame_idx - half_N_frames * interval, - center_frame_idx + half_N_frames * interval + 1, interval)) - if random_reverse and random.random() < 0.5: - neighbor_list.reverse() - name_b = '{:08d}'.format(neighbor_list[half_N_frames]) - assert len(neighbor_list) == number_frames, \ - "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames) - - return neighbor_list, name_b - - -def read_img(path, size=None, is_gt=False): - """read image by cv2 - return: Numpy float32, HWC, BGR, [0,1]""" - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - #if not is_gt: - # #print(path) - # img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25) - img = img.astype(np.float32) / 255. - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - # some images have 4 channels - if img.shape[2] > 3: - img = img[:, :, :3] - return img - - -def img_augment(img_list, hflip=True, rot=True): - """horizontal flip OR rotate (0, 90, 180, 270 degrees)""" - hflip = hflip and random.random() < 0.5 - vflip = rot and random.random() < 0.5 - rot90 = rot and random.random() < 0.5 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - return img - - return [_augment(img) for img in img_list] - - -def make_reader(filelist, - num_threads, - batch_size, - is_training, - number_frames, - interval_list, - random_reverse, - gtroot, - fileroot, - LR_input, - crop_size, - scale, - use_flip, - use_rot, - mode='train'): - fl = filelist - def reader_(): - if is_training: - random.shuffle(fl) - batch_out = [] - for item in fl: - img_LQs, img_GT = get_sample_data(item, - number_frames, interval_list, random_reverse, gtroot, fileroot, - LR_input, crop_size, scale, use_flip, use_rot, mode) - videoname = item.split('_')[0] - framename = item.split('_')[1] - if (mode == 'train') or (mode == 'valid'): - batch_out.append((img_LQs, img_GT)) - elif mode == 'test': - batch_out.append((img_LQs, img_GT, videoname, framename)) - else: - raise NotImplementedError("mode {} not implemented".format(mode)) - if len(batch_out) == batch_size: - yield batch_out - batch_out = [] - return reader_ - - -def make_multi_reader(filelist, - num_threads, - batch_size, - is_training, - number_frames, - interval_list, - random_reverse, - gtroot, - fileroot, - LR_input, - crop_size, - scale, - use_flip, - use_rot, - mode='train'): - def read_into_queue(flq, queue): - batch_out = [] - for item in flq: - img_LQs, img_GT = get_sample_data(item, - number_frames, interval_list, random_reverse, gtroot, fileroot, - LR_input, crop_size, scale, use_flip, use_rot, mode) - videoname = item.split('_')[0] - framename = item.split('_')[1] - if (mode == 'train') or (mode == 'valid'): - batch_out.append((img_LQs, img_GT)) - elif mode == 'test': - batch_out.append((img_LQs, img_GT, videoname, framename)) - else: - raise NotImplementedError("mode {} not implemented".format(mode)) - if len(batch_out) == batch_size: - queue.put(batch_out) - batch_out = [] - queue.put(None) - - - def queue_reader(): - fl = filelist - if is_training: - random.shuffle(fl) - - n = num_threads - queue_size = 20 - reader_lists = [None] * n - file_num = int(len(fl) // n) - for i in range(n): - if i < len(reader_lists) - 1: - tmp_list = fl[i * file_num:(i + 1) * file_num] - else: - tmp_list = fl[i * file_num:] - reader_lists[i] = tmp_list - - queue = multiprocessing.Queue(queue_size) - p_list = [None] * len(reader_lists) - # for reader_list in reader_lists: - for i in range(len(reader_lists)): - reader_list = reader_lists[i] - p_list[i] = multiprocessing.Process( - target=read_into_queue, args=(reader_list, queue)) - p_list[i].start() - reader_num = len(reader_lists) - finish_num = 0 - while finish_num < reader_num: - sample = queue.get() - if sample is None: - finish_num += 1 - else: - yield sample - for i in range(len(p_list)): - if p_list[i].is_alive(): - p_list[i].join() - - return queue_reader diff --git a/applications/EDVR/reader/edvr_reader.pybk b/applications/EDVR/reader/edvr_reader.pybk deleted file mode 100644 index bb45925..0000000 --- a/applications/EDVR/reader/edvr_reader.pybk +++ /dev/null @@ -1,398 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. - -import os -import sys -import cv2 -import math -import random -import multiprocessing -import functools -import numpy as np -import paddle -import cv2 -import logging -from .reader_utils import DataReader - -logger = logging.getLogger(__name__) -python_ver = sys.version_info - -random.seed(0) -np.random.seed(0) - -class EDVRReader(DataReader): - """ - Data reader for video super resolution task fit for EDVR model. - This is specified for REDS dataset. - """ - def __init__(self, name, mode, cfg): - super(EDVRReader, self).__init__(name, mode, cfg) - self.format = cfg.MODEL.format - self.crop_size = self.get_config_from_sec(mode, 'crop_size') - self.interval_list = self.get_config_from_sec(mode, 'interval_list') - self.random_reverse = self.get_config_from_sec(mode, 'random_reverse') - self.number_frames = self.get_config_from_sec(mode, 'number_frames') - # set batch size and file list - self.batch_size = cfg[mode.upper()]['batch_size'] - self.fileroot = cfg[mode.upper()]['file_root'] - self.use_flip = self.get_config_from_sec(mode, 'use_flip', False) - self.use_rot = self.get_config_from_sec(mode, 'use_rot', False) - - self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1) - self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024) - self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False) - - if self.mode != 'infer': - self.gtroot = self.get_config_from_sec(mode, 'gt_root') - self.scale = self.get_config_from_sec(mode, 'scale', 1) - self.LR_input = (self.scale > 1) - if self.fix_random_seed: - random.seed(0) - np.random.seed(0) - self.num_reader_threads = 1 - - """ - def create_reader(self): - logger.info('initialize reader ... ') - self.filelist = [] - for video_name in os.listdir(self.fileroot): - if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']): - continue - for frame_name in os.listdir(os.path.join(self.fileroot, video_name)): - frame_idx = frame_name.split('.')[0] - video_frame_idx = video_name + '_' + frame_idx - # for each item in self.filelist is like '010_00000015', '260_00000090' - self.filelist.append(video_frame_idx) - - #self.filelist.sort() - - def reader_(): - ### not implemented border mode, maybe add later ############ - if self.mode == 'train': - random.shuffle(self.filelist) - for item in self.filelist: - #print(item) - video_name = item.split('_')[0] - frame_name = item.split('_')[1] - ngb_frames, name_b = get_neighbor_frames(frame_name, \ - number_frames = self.number_frames, \ - interval_list = self.interval_list, \ - random_reverse = self.random_reverse) - frame_name = name_b - #print('key2', ngb_frames, name_b) - img_GT = read_img(os.path.join(self.gtroot, video_name, frame_name + '.png')) - #print('gt_mean', np.mean(img_GT)) - frame_list = [] - for ngb_frm in ngb_frames: - ngb_name = "%08d"%ngb_frm - #img = read_img(os.path.join(self.fileroot, video_name, frame_name + '.png')) - img = read_img(os.path.join(self.fileroot, video_name, ngb_name + '.png')) - frame_list.append(img) - #print('img_mean', np.mean(img)) - - H, W, C = frame_list[0].shape - # add random crop - if self.LR_input: - LQ_size = self.crop_size // self.scale - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - #print('rnd_h {}, rnd_w {}', rnd_h, rnd_w) - frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list] - rnd_h_HR, rnd_w_HR = int(rnd_h * self.scale), int(rnd_w * self.scale) - img_GT = img_GT[rnd_h_HR:rnd_h_HR + self.crop_size, rnd_w_HR:rnd_w_HR + self.crop_size, :] - else: - rnd_h = random.randint(0, max(0, H - self.crop_size)) - rnd_w = random.randint(0, max(0, W - self.crop_size)) - frame_list = [v[rnd_h:rnd_h + self.crop_size, rnd_w:rnd_w + self.crop_size, :] for v in frame_list] - img_GT = img_GT[rnd_h:rnd_h + self.crop_size, rnd_w:rnd_w + self.crop_size, :] - - # add random flip and rotation - frame_list.append(img_GT) - rlt = img_augment(frame_list, self.use_flip, self.use_rot) - frame_list = rlt[0:-1] - img_GT = rlt[-1] - - # stack LQ images to NHWC, N is the frame number - img_LQs = np.stack(frame_list, axis=0) - # BGR to RGB, HWC to CHW, numpy to tensor - img_GT = img_GT[:, :, [2, 1, 0]] - img_LQs = img_LQs[:, :, :, [2, 1, 0]] - img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32') - img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32') - yield img_LQs, img_GT - - def _batch_reader(): - batch_out = [] - for img_LQs, img_GT in reader_(): - #print('lq', img_LQs.shape) - #print('gt', img_GT.shape) - batch_out.append((img_LQs, img_GT)) - if len(batch_out) == self.batch_size: - yield batch_out - batch_out = [] - - return _batch_reader - """ - - def create_reader(self): - logger.info('initialize reader ... ') - self.filelist = [] - for video_name in os.listdir(self.fileroot): - if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']): - continue - for frame_name in os.listdir(os.path.join(self.fileroot, video_name)): - frame_idx = frame_name.split('.')[0] - video_frame_idx = video_name + '_' + frame_idx - # for each item in self.filelist is like '010_00000015', '260_00000090' - self.filelist.append(video_frame_idx) - - if self.num_reader_threads == 1: - reader_func = make_reader - else: - reader_func = make_multi_reader - - - return reader_func(filelist = self.filelist, - num_threads = self.num_reader_threads, - batch_size = self.batch_size, - is_training = (self.mode == 'train'), - number_frames = self.number_frames, - interval_list = self.interval_list, - random_reverse = self.random_reverse, - gtroot = self.gtroot, - fileroot = self.fileroot, - LR_input = self.LR_input, - crop_size = self.crop_size, - scale = self.scale, - use_flip = self.use_flip, - use_rot = self.use_rot) - - -def get_sample_data(item, number_frames, interval_list, random_reverse, gtroot, fileroot, - LR_input, crop_size, scale, use_flip, use_rot): - video_name = item.split('_')[0] - frame_name = item.split('_')[1] - ngb_frames, name_b = get_neighbor_frames(frame_name, \ - number_frames = number_frames, \ - interval_list = interval_list, \ - random_reverse = random_reverse) - frame_name = name_b - #print('key2', ngb_frames, name_b) - img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png')) - #print('gt_mean', np.mean(img_GT)) - frame_list = [] - for ngb_frm in ngb_frames: - ngb_name = "%08d"%ngb_frm - #img = read_img(os.path.join(fileroot, video_name, frame_name + '.png')) - img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png')) - frame_list.append(img) - #print('img_mean', np.mean(img)) - - H, W, C = frame_list[0].shape - # add random crop - if LR_input: - LQ_size = crop_size // scale - rnd_h = random.randint(0, max(0, H - LQ_size)) - rnd_w = random.randint(0, max(0, W - LQ_size)) - #print('rnd_h {}, rnd_w {}', rnd_h, rnd_w) - frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list] - rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) - img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :] - else: - rnd_h = random.randint(0, max(0, H - crop_size)) - rnd_w = random.randint(0, max(0, W - crop_size)) - frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list] - img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] - - # add random flip and rotation - frame_list.append(img_GT) - rlt = img_augment(frame_list, use_flip, use_rot) - frame_list = rlt[0:-1] - img_GT = rlt[-1] - - # stack LQ images to NHWC, N is the frame number - img_LQs = np.stack(frame_list, axis=0) - # BGR to RGB, HWC to CHW, numpy to tensor - img_GT = img_GT[:, :, [2, 1, 0]] - img_LQs = img_LQs[:, :, :, [2, 1, 0]] - img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32') - img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32') - - return img_LQs, img_GT - - -def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, bordermode=False): - center_frame_idx = int(frame_name) - half_N_frames = number_frames // 2 - #### determine the neighbor frames - interval = random.choice(interval_list) - if bordermode: - direction = 1 # 1: forward; 0: backward - if random_reverse and random.random() < 0.5: - direction = random.choice([0, 1]) - if center_frame_idx + interval * (number_frames - 1) > 99: - direction = 0 - elif center_frame_idx - interval * (number_frames - 1) < 0: - direction = 1 - # get the neighbor list - if direction == 1: - neighbor_list = list( - range(center_frame_idx, center_frame_idx + interval * number_frames, interval)) - else: - neighbor_list = list( - range(center_frame_idx, center_frame_idx - interval * number_frames, -interval)) - name_b = '{:08d}'.format(neighbor_list[0]) - else: - # ensure not exceeding the borders - while (center_frame_idx + half_N_frames * interval > - 99) or (center_frame_idx - half_N_frames * interval < 0): - center_frame_idx = random.randint(0, 99) - # get the neighbor list - neighbor_list = list( - range(center_frame_idx - half_N_frames * interval, - center_frame_idx + half_N_frames * interval + 1, interval)) - if random_reverse and random.random() < 0.5: - neighbor_list.reverse() - name_b = '{:08d}'.format(neighbor_list[half_N_frames]) - assert len(neighbor_list) == number_frames, \ - "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames) - - return neighbor_list, name_b - - -def read_img(path, size=None): - """read image by cv2 - return: Numpy float32, HWC, BGR, [0,1]""" - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - img = img.astype(np.float32) / 255. - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - # some images have 4 channels - if img.shape[2] > 3: - img = img[:, :, :3] - return img - - -def img_augment(img_list, hflip=True, rot=True): - """horizontal flip OR rotate (0, 90, 180, 270 degrees)""" - hflip = hflip and random.random() < 0.5 - vflip = rot and random.random() < 0.5 - rot90 = rot and random.random() < 0.5 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - return img - - return [_augment(img) for img in img_list] - - -def make_reader(filelist, - num_threads, - batch_size, - is_training, - number_frames, - interval_list, - random_reverse, - gtroot, - fileroot, - LR_input, - crop_size, - scale, - use_flip, - use_rot): - fl = filelist - def reader_(): - if is_training: - random.shuffle(fl) - batch_out = [] - for item in fl: - img_LQs, img_GT = get_sample_data(item, - number_frames, interval_list, random_reverse, gtroot, fileroot, - LR_input, crop_size, scale, use_flip, use_rot) - batch_out.append((img_LQs, img_GT)) - if len(batch_out) == batch_size: - yield batch_out - batch_out = [] - return reader_ - - -def make_multi_reader(filelist, - num_threads, - batch_size, - is_training, - number_frames, - interval_list, - random_reverse, - gtroot, - fileroot, - LR_input, - crop_size, - scale, - use_flip, - use_rot): - def read_into_queue(flq, queue): - batch_out = [] - for item in flq: - img_LQs, img_GT = get_sample_data(item, - number_frames, interval_list, random_reverse, gtroot, fileroot, - LR_input, crop_size, scale, use_flip, use_rot) - batch_out.append((img_LQs, img_GT)) - if len(batch_out) == batch_size: - queue.put(batch_out) - batch_out = [] - queue.put(None) - - - def queue_reader(): - fl = filelist - if is_training: - random.shuffle(fl) - - n = num_threads - queue_size = 20 - reader_lists = [None] * n - file_num = int(len(fl) // n) - for i in range(n): - if i < len(reader_lists) - 1: - tmp_list = fl[i * file_num:(i + 1) * file_num] - else: - tmp_list = fl[i * file_num:] - reader_lists[i] = tmp_list - - queue = multiprocessing.Queue(queue_size) - p_list = [None] * len(reader_lists) - # for reader_list in reader_lists: - for i in range(len(reader_lists)): - reader_list = reader_lists[i] - p_list[i] = multiprocessing.Process( - target=read_into_queue, args=(reader_list, queue)) - p_list[i].start() - reader_num = len(reader_lists) - finish_num = 0 - while finish_num < reader_num: - sample = queue.get() - if sample is None: - finish_num += 1 - else: - yield sample - for i in range(len(p_list)): - if p_list[i].is_alive(): - p_list[i].join() - - return queue_reader diff --git a/applications/EDVR/run.sh b/applications/EDVR/run.sh index 9702462..90b939c 100644 --- a/applications/EDVR/run.sh +++ b/applications/EDVR/run.sh @@ -11,73 +11,24 @@ mode=$1 name=$2 configs=$3 -pretrain="./tmp/name_map/paddle_state_dict.npz" # set pretrain model path if needed -resume="" # set pretrain model path if needed -save_dir="./data/checkpoints" +#pretrain="./tmp/name_map/paddle_state_dict.npz" # set pretrain model path if needed +#resume="" # set pretrain model path if needed +#save_dir="./data/checkpoints" save_inference_dir="./data/inference_model" use_gpu=True fix_random_seed=False log_interval=1 valid_interval=1 -#weights="./data/checkpoints/EDVR_epoch721.pdparams" #set the path of weights to enable eval and predicut, just ignore this when training - -#weights="./data/checkpoints_with_tsa/EDVR_epoch821.pdparams" weights="./weights/paddle_state_dict_L.npz" -#weights="./weights/" -#export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=4,5,6,7 #0,1,5,6 fast, 2,3,4,7 slow -#export CUDA_VISIBLE_DEVICES=7 export FLAGS_fast_eager_deletion_mode=1 export FLAGS_eager_delete_tensor_gb=0.0 export FLAGS_fraction_of_gpu_memory_to_use=0.98 -if [ "$mode" == "train" ]; then - echo $mode $name $configs $resume $pretrain - if [ "$resume"x != ""x ]; then - python train.py --model_name=$name \ - --config=$configs \ - --resume=$resume \ - --log_interval=$log_interval \ - --valid_interval=$valid_interval \ - --use_gpu=$use_gpu \ - --save_dir=$save_dir \ - --fix_random_seed=$fix_random_seed - elif [ "$pretrain"x != ""x ]; then - python train.py --model_name=$name \ - --config=$configs \ - --pretrain=$pretrain \ - --log_interval=$log_interval \ - --valid_interval=$valid_interval \ - --use_gpu=$use_gpu \ - --save_dir=$save_dir \ - --fix_random_seed=$fix_random_seed - else - python train.py --model_name=$name \ - --config=$configs \ - --log_interval=$log_interval \ - --valid_interval=$valid_interval \ - --use_gpu=$use_gpu \ - --save_dir=$save_dir \ - --fix_random_seed=$fix_random_seed - fi -elif [ "$mode"x == "eval"x ]; then - echo $mode $name $configs $weights - if [ "$weights"x != ""x ]; then - python eval.py --model_name=$name \ - --config=$configs \ - --log_interval=$log_interval \ - --weights=$weights \ - --use_gpu=$use_gpu - else - python eval.py --model_name=$name \ - --config=$configs \ - --log_interval=$log_interval \ - --use_gpu=$use_gpu - fi -elif [ "$mode"x == "predict"x ]; then +if [ "$mode"x == "predict"x ]; then echo $mode $name $configs $weights if [ "$weights"x != ""x ]; then python predict.py --model_name=$name \ -- GitLab