Commit 8c46c443 authored by lijianshe02

delete code unrelated to inference

Parent 39937404
......@@ -11,43 +11,6 @@ MODEL:
HR_in: False
w_TSA: True #False
TRAIN:
epoch: 45
use_gpu: True
num_gpus: 4 #8
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 32
file_root: "/workspace/video_test/video/data/dataset/edvr/REDS/train_sharp_bicubic/X4"
gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS/train_sharp"
use_flip: True
use_rot: True
base_learning_rate: 0.0004
l2_weight_decay: 0.0
TSA_only: False
T_periods: [50000, 100000, 150000, 150000, 150000] # for cosine annealing restart
restarts: [50000, 150000, 300000, 450000] # for cosine annealing restart
weights: [1, 1, 1, 1] # for cosine annealing restart
eta_min: 1e-7 # for cosine annealing restart
num_reader_threads: 8
buf_size: 1024
fix_random_seed: False
VALID:
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 32 #256
file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic"
gt_root: "/workspace/video_test/video/data/dataset/edvr/edvr/REDS4/GT"
use_flip: False
use_rot: False
TEST:
scale: 4
crop_size: 256
......
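The `T_periods`, `restarts`, `weights` and `eta_min` fields removed from the TRAIN section above parameterize a cosine-annealing learning-rate schedule with restarts. The sketch below shows one plausible way such a schedule could be evaluated. The function name `cosine_restart_lr` and the exact interpretation of the fields (each restart resets the cosine phase and scales the peak rate by the matching weight) are assumptions for illustration, not code from this repository.

```python
import math
from bisect import bisect_right

def cosine_restart_lr(step,
                      base_lr=0.0004,
                      eta_min=1e-7,
                      t_periods=(50000, 100000, 150000, 150000, 150000),
                      restarts=(50000, 150000, 300000, 450000),
                      weights=(1, 1, 1, 1)):
    """Learning rate at `step` for cosine annealing with restarts (assumed semantics)."""
    # Number of restarts that have already happened at this step.
    idx = bisect_right(restarts, step)
    start = restarts[idx - 1] if idx > 0 else 0
    weight = weights[idx - 1] if idx > 0 else 1.0
    period = t_periods[idx]
    peak = base_lr * weight
    # Cosine decay from `peak` down to `eta_min` over the current period.
    phase = math.pi * (step - start) / period
    return eta_min + (peak - eta_min) * (1 + math.cos(phase)) / 2

for s in (0, 25000, 49999, 50000):
    print(s, cosine_restart_lr(s))
```

With the values from the config, the rate starts at `base_learning_rate`, decays towards `eta_min` over the first 50000 iterations, and jumps back to the peak at each entry of `restarts`.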
from .model import regist_model, get_model
#from .attention_cluster import AttentionCluster
#from .attention_lstm import AttentionLSTM
#from .nextvlad import NEXTVLAD
#from .nonlocal_model import NonLocal
#from .tsm import TSM
#from .tsn import TSN
#from .stnet import STNET
#from .ctcn import CTCN
#from .bmn import BMN
#from .bsn import BsnTem
#from .bsn import BsnPem
#from .ets import ETS
#from .tall import TALL
from .edvr import EDVR
# regist models, sort by alphabet
#regist_model("AttentionCluster", AttentionCluster)
#regist_model("AttentionLSTM", AttentionLSTM)
#regist_model("NEXTVLAD", NEXTVLAD)
#regist_model('NONLOCAL', NonLocal)
#regist_model("TSM", TSM)
#regist_model("TSN", TSN)
#regist_model("STNET", STNET)
#regist_model("CTCN", CTCN)
#regist_model("BMN", BMN)
#regist_model("BsnTem", BsnTem)
#regist_model("BsnPem", BsnPem)
#regist_model("ETS", ETS)
#regist_model("TALL", TALL)
regist_model("EDVR", EDVR)
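The `regist_model` / `get_model` pair imported above follows a name-to-class registry pattern: after this commit only `EDVR` stays registered. The internals of `.model` are not shown in this diff, so the following is only a minimal sketch of how such a registry might work; `ModelRegistry` and the placeholder `EDVR` class are hypothetical.

```python
class ModelRegistry:
    """Hypothetical name -> class registry, illustrating the pattern only."""

    def __init__(self):
        self._models = {}

    def regist(self, name, model_cls):
        # Refuse duplicate names so a typo cannot silently replace a model.
        if name in self._models:
            raise ValueError("model {} is already registered".format(name))
        self._models[name] = model_cls

    def get(self, name, *args, **kwargs):
        if name not in self._models:
            raise KeyError("unknown model {}, available: {}".format(
                name, sorted(self._models)))
        # Instantiate the registered class with the caller's arguments.
        return self._models[name](*args, **kwargs)


class EDVR:
    """Placeholder standing in for the real EDVR model class."""

    def __init__(self, mode="infer"):
        self.mode = mode


registry = ModelRegistry()
registry.regist("EDVR", EDVR)
model = registry.get("EDVR", mode="infer")
print(type(model).__name__, model.mode)
```

Under this pattern, requesting a model whose registration line was commented out (for example `TSN`) fails at lookup time rather than at import time.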
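# TSN Video Classification Model
---
## Contents
- [Model Introduction](#模型简介)
- [Data Preparation](#数据准备)
- [Model Training](#模型训练)
- [Model Evaluation](#模型评估)
- [Model Inference](#模型推断)
- [References](#参考论文)
## Model Introduction
Temporal Segment Network (TSN) is a classic 2D-CNN-based solution for video classification. The method mainly targets long-range action recognition in videos: it replaces dense sampling with sparse sampling of video frames, which captures the global information of the video while removing redundancy and reducing computation. The per-frame features are then fused by averaging to obtain a video-level feature, which is used for classification. The model implemented here is a single-stream RGB TSN network with a ResNet-50 backbone.
For details, see the ECCV 2016 paper [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859)
## Data Preparation
TSN is trained on the Kinetics-400 action recognition dataset released by DeepMind. For downloading and preparing the data, see the [dataset notes](../../data/dataset/README.md)
## Model Training
Once the data is ready, training can be started in either of the following two ways: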
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python train.py --model_name=TSN \
--config=./configs/tsn.yaml \
--log_interval=10 \
--valid_interval=1 \
--use_gpu=True \
--save_dir=./data/checkpoints \
--fix_random_seed=False \
--pretrain=$PATH_TO_PRETRAIN_MODEL
bash run.sh train TSN ./configs/tsn.yaml
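- To train from scratch, ResNet50 weights trained on ImageNet must be loaded as the initialization parameters. Download these [model parameters](https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz) and extract them, then set the `pretrain` argument in the command line above or in the run.sh script to the path where the extracted parameters are stored. If `pretrain` is not downloaded and set manually, the program downloads the parameters automatically and saves them under ~/.paddle/weights/ResNet50\_pretrained
- The released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) can be downloaded and its weight path passed via `--resume` for fine-tuning and other development
**Data reader:** The model reads `mp4` data from the Kinetics-400 dataset. Each video is split into `seg_num` segments, one frame is sampled from each segment, and each frame is randomly augmented and then resized to `target_size`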
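The sparse sampling described in the data reader note can be sketched as follows. This is a minimal illustration, assuming equal-length segments, a random frame per segment during training and the middle frame otherwise; the function name `sample_segment_indices` is hypothetical and the actual reader may differ in its details.

```python
import random

def sample_segment_indices(num_frames, seg_num, training=True, rng=random):
    """Pick one frame index from each of `seg_num` equal-length segments."""
    seg_len = num_frames // seg_num
    if seg_len == 0:
        raise ValueError("video has fewer frames than segments")
    indices = []
    for seg in range(seg_num):
        start = seg * seg_len
        # Random offset while training, centre frame otherwise (assumed policy).
        offset = rng.randrange(seg_len) if training else seg_len // 2
        indices.append(start + offset)
    return indices

print(sample_segment_indices(300, 3, training=False))  # [50, 150, 250]
print(sample_segment_indices(300, 7, training=True))
```

Because only `seg_num` frames are decoded per video, the cost per sample stays constant regardless of video length.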
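**Training strategy:**
* Trained with the Momentum optimizer, momentum=0.9
* Weight decay coefficient of 1e-4
* The learning rate is decayed by a factor of 0.1 at 1/3 and at 2/3 of the total number of epochs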
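The step decay in the training strategy above can be written as a small function. This is a sketch only: the base learning rate and epoch count used in the example are made-up values, and `step_decay_lr` is a hypothetical helper rather than part of the training script.

```python
def step_decay_lr(epoch, total_epochs, base_lr, gamma=0.1):
    """Multiply the rate by `gamma` at 1/3 and again at 2/3 of training."""
    boundaries = [total_epochs // 3, 2 * total_epochs // 3]
    # Count how many decay boundaries have been reached so far.
    num_decays = sum(epoch >= b for b in boundaries)
    return base_lr * gamma ** num_decays

# Hypothetical run of 45 epochs starting from 0.01.
for epoch in (0, 14, 15, 29, 30, 44):
    print(epoch, step_decay_lr(epoch, total_epochs=45, base_lr=0.01))
```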
## Model Evaluation
The model can be evaluated in either of the following two ways:
python eval.py --model_name=TSN \
--config=./configs/tsn.yaml \
--log_interval=1 \
--weights=$PATH_TO_WEIGHTS \
--use_gpu=True
bash run.sh eval TSN ./configs/tsn.yaml
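- When evaluating with `run.sh`, modify the `weights` parameter in the script to specify the weights to evaluate
- If `--weights` is not specified, the script downloads the released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) for evaluation
- Evaluation results are printed directly as log output, including accuracy metrics such as TOP1\_ACC and TOP5\_ACC
- To evaluate on CPU, set `use_gpu` to False in the command line above or in the run.sh script
With the following parameters, the accuracy on the Kinetics400 validation set is:
| seg\_num | target\_size | Top-1 |
| :------: | :----------: | :----: |
| 3 | 224 | 0.66 |
| 7 | 224 | 0.67 |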
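For reference, the Top-1 and Top-5 metrics reported by the evaluation log are instances of top-k accuracy: a sample counts as correct when its label is among the k highest-scoring classes. The sketch below is a plain-Python illustration with toy scores; `topk_accuracy` is a hypothetical helper, not the metric code used by `eval.py`.

```python
def topk_accuracy(scores, labels, k=1):
    """Fraction of samples whose label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by descending score, truncated to the top k.
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in topk
    return hits / len(labels)

# Toy example: 3 samples, 3 classes.
scores = [[0.1, 0.7, 0.2],
          [0.5, 0.3, 0.2],
          [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
print(topk_accuracy(scores, labels, k=1))
print(topk_accuracy(scores, labels, k=2))
```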
## Model Inference
Model inference can be started in either of the following two ways:
python predict.py --model_name=TSN \
--config=./configs/tsn.yaml \
--log_interval=1 \
--weights=$PATH_TO_WEIGHTS \
--filelist=$FILELIST \
--use_gpu=True \
--video_path=$VIDEO_PATH
bash run.sh predict TSN ./configs/tsn.yaml
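- When running inference with `run.sh`, modify the `weights` parameter in the script to specify the weights to use.
- If video\_path is '', this argument is ignored. If video\_path != '', the program runs prediction on the video file given by video\_path and ignores the value of filelist; the prediction result is the classification probabilities for that video.
- If `--weights` is not specified, the script downloads the released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) for inference
- Inference results are printed directly as log output, showing the predicted class probabilities for the test samples.
- To run inference on CPU, set `use_gpu` to False in the command line or in the run.sh script
## References
- [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859), Limin Wang, Yuanjun Xiong, Zhe Wang, Yu Qiao, Dahua Lin, Xiaoou Tang, Luc Van Gool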
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import cv2
import math
import random
import multiprocessing
import functools
import numpy as np
import paddle
import logging

from .reader_utils import DataReader

logger = logging.getLogger(__name__)
python_ver = sys.version_info

random.seed(0)
np.random.seed(0)


class EDVRReader(DataReader):
    """
    Data reader for video super resolution task fit for EDVR model.
    This is specified for REDS dataset.
    """

    def __init__(self, name, mode, cfg):
        super(EDVRReader, self).__init__(name, mode, cfg)
        self.format = cfg.MODEL.format
        self.crop_size = self.get_config_from_sec(mode, 'crop_size')
        self.interval_list = self.get_config_from_sec(mode, 'interval_list')
        self.random_reverse = self.get_config_from_sec(mode, 'random_reverse')
        self.number_frames = self.get_config_from_sec(mode, 'number_frames')
        # set batch size and file list
        self.batch_size = cfg[mode.upper()]['batch_size']
        self.fileroot = cfg[mode.upper()]['file_root']
        self.use_flip = self.get_config_from_sec(mode, 'use_flip', False)
        self.use_rot = self.get_config_from_sec(mode, 'use_rot', False)
        self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1)
        self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024)
        self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False)
        if self.mode != 'infer':
            self.gtroot = self.get_config_from_sec(mode, 'gt_root')
        self.scale = self.get_config_from_sec(mode, 'scale', 1)
        self.LR_input = (self.scale > 1)
        if self.fix_random_seed:
            random.seed(0)
            np.random.seed(0)
            self.num_reader_threads = 1

    def create_reader(self):
        logger.info('initialize reader ... ')
        self.filelist = []
        for video_name in os.listdir(self.fileroot):
            if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
                continue
            for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
                frame_idx = frame_name.split('.')[0]
                video_frame_idx = video_name + '_' + frame_idx
                # for each item in self.filelist is like '010_00000015', '260_00000090'
                self.filelist.append(video_frame_idx)
        if self.mode == 'test':
            self.filelist.sort()
        if self.num_reader_threads == 1:
            reader_func = make_reader
        else:
            reader_func = make_multi_reader
        return reader_func(filelist=self.filelist,
                           num_threads=self.num_reader_threads,
                           batch_size=self.batch_size,
                           is_training=(self.mode == 'train'),
                           number_frames=self.number_frames,
                           interval_list=self.interval_list,
                           random_reverse=self.random_reverse,
                           gtroot=self.gtroot,
                           fileroot=self.fileroot,
                           LR_input=self.LR_input,
                           crop_size=self.crop_size,
                           scale=self.scale,
                           use_flip=self.use_flip,
                           use_rot=self.use_rot,
                           mode=self.mode)


def get_sample_data(item, number_frames, interval_list, random_reverse, gtroot, fileroot,
                    LR_input, crop_size, scale, use_flip, use_rot, mode='train'):
    video_name = item.split('_')[0]
    frame_name = item.split('_')[1]
    if (mode == 'train') or (mode == 'valid'):
        ngb_frames, name_b = get_neighbor_frames(frame_name,
                                                 number_frames=number_frames,
                                                 interval_list=interval_list,
                                                 random_reverse=random_reverse)
    elif mode == 'test':
        ngb_frames, name_b = get_test_neighbor_frames(int(frame_name), number_frames)
    else:
        raise NotImplementedError('mode {} not implemented'.format(mode))
    frame_name = name_b
    #print('key2', ngb_frames, name_b)
    img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'), is_gt=True)
    #print('gt_mean', np.mean(img_GT))
    frame_list = []
    for ngb_frm in ngb_frames:
        ngb_name = "%08d" % ngb_frm
        #img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
        img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png'))
        frame_list.append(img)
        #print('img_mean', np.mean(img))
    H, W, C = frame_list[0].shape
    # add random crop
    if (mode == 'train') or (mode == 'valid'):
        if LR_input:
            LQ_size = crop_size // scale
            rnd_h = random.randint(0, max(0, H - LQ_size))
            rnd_w = random.randint(0, max(0, W - LQ_size))
            #print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
            frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :]
        else:
            rnd_h = random.randint(0, max(0, H - crop_size))
            rnd_w = random.randint(0, max(0, W - crop_size))
            frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list]
            img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
    # add random flip and rotation
    frame_list.append(img_GT)
    if (mode == 'train') or (mode == 'valid'):
        rlt = img_augment(frame_list, use_flip, use_rot)
    else:
        rlt = frame_list
    frame_list = rlt[0:-1]
    img_GT = rlt[-1]
    # stack LQ images to NHWC, N is the frame number
    img_LQs = np.stack(frame_list, axis=0)
    # BGR to RGB, HWC to CHW, numpy to tensor
    img_GT = img_GT[:, :, [2, 1, 0]]
    img_LQs = img_LQs[:, :, :, [2, 1, 0]]
    img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
    img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
    return img_LQs, img_GT
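
The key step in `get_sample_data` is that the low-resolution frames and the ground-truth frame are cropped at aligned positions: the LR crop is `crop_size // scale` pixels wide and the GT crop starts at the LR offset multiplied by `scale`. The sketch below isolates just that arithmetic. `paired_crop_boxes` is a hypothetical helper, and the 180x320 LR frame size used in the example is an assumption about the X4 REDS data rather than something stated in this diff.

```python
import random

def paired_crop_boxes(lq_h, lq_w, crop_size, scale, rng=random):
    """Return (top, left, height, width) boxes for the LR frames and the GT frame."""
    lq_size = crop_size // scale
    # Random top-left corner inside the low-resolution frame.
    top = rng.randint(0, max(0, lq_h - lq_size))
    left = rng.randint(0, max(0, lq_w - lq_size))
    lq_box = (top, left, lq_size, lq_size)
    # The ground-truth crop covers the same region at `scale` times the resolution.
    gt_box = (top * scale, left * scale, crop_size, crop_size)
    return lq_box, gt_box

lq_box, gt_box = paired_crop_boxes(lq_h=180, lq_w=320, crop_size=256, scale=4)
print(lq_box, gt_box)
```

With `crop_size: 256` and `scale: 4` from the config, every LR patch is 64x64 and its GT patch is 256x256.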


def get_test_neighbor_frames(crt_i, N, max_n=100, padding='new_info'):
    """Generate an index list for reading N frames from a sequence of images

    Args:
        crt_i (int): current center index
        max_n (int): max number of the sequence of images (calculated from 1)
        N (int): reading N frames
        padding (str): padding mode, one of replicate | reflection | new_info | circle
            Example: crt_i = 0, N = 5
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            new_info: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        return_l (list [int]): a list of indexes
    """
    max_n = max_n - 1
    n_pad = N // 2
    return_l = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            if padding == 'replicate':
                add_idx = 0
            elif padding == 'reflection':
                add_idx = -i
            elif padding == 'new_info':
                add_idx = (crt_i + n_pad) + (-i)
            elif padding == 'circle':
                add_idx = N + i
            else:
                raise ValueError('Wrong padding mode')
        elif i > max_n:
            if padding == 'replicate':
                add_idx = max_n
            elif padding == 'reflection':
                add_idx = max_n * 2 - i
            elif padding == 'new_info':
                add_idx = (crt_i - n_pad) - (i - max_n)
            elif padding == 'circle':
                add_idx = i - N
            else:
                raise ValueError('Wrong padding mode')
        else:
            add_idx = i
        return_l.append(add_idx)
    name_b = '{:08d}'.format(crt_i)
    return return_l, name_b
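
The docstring above only shows the padding behaviour at the start of a sequence. The compact re-implementation below, written for illustration under the same formulas, also exercises the end of a 100-frame sequence. `padded_window` is a hypothetical name; it returns only the index list and omits the frame-name string.

```python
def padded_window(center, n_frames, seq_len=100, padding="new_info"):
    """Indices of an n_frames window around `center`, padded at sequence borders."""
    last = seq_len - 1
    half = n_frames // 2
    window = []
    for i in range(center - half, center + half + 1):
        if i < 0:
            # Index falls before the first frame.
            idx = {"replicate": 0,
                   "reflection": -i,
                   "new_info": (center + half) - i,
                   "circle": n_frames + i}[padding]
        elif i > last:
            # Index falls after the last frame.
            idx = {"replicate": last,
                   "reflection": 2 * last - i,
                   "new_info": (center - half) - (i - last),
                   "circle": i - n_frames}[padding]
        else:
            idx = i
        window.append(idx)
    return window

print(padded_window(0, 5))                         # [4, 3, 0, 1, 2]
print(padded_window(99, 5))                        # [97, 98, 99, 96, 95]
print(padded_window(99, 5, padding="reflection"))  # [97, 98, 99, 98, 97]
```

The `new_info` mode, the default used in test mode, fills out-of-range slots with frames from further inside the sequence instead of repeating frames already in the window.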


def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, max_frame=99, bordermode=False):
    center_frame_idx = int(frame_name)
    half_N_frames = number_frames // 2
    #### determine the neighbor frames
    interval = random.choice(interval_list)
    if bordermode:
        direction = 1  # 1: forward; 0: backward
        if random_reverse and random.random() < 0.5:
            direction = random.choice([0, 1])
        if center_frame_idx + interval * (number_frames - 1) > max_frame:
            direction = 0
        elif center_frame_idx - interval * (number_frames - 1) < 0:
            direction = 1
        # get the neighbor list
        if direction == 1:
            neighbor_list = list(
                range(center_frame_idx, center_frame_idx + interval * number_frames, interval))
        else:
            neighbor_list = list(
                range(center_frame_idx, center_frame_idx - interval * number_frames, -interval))
        name_b = '{:08d}'.format(neighbor_list[0])
    else:
        # ensure not exceeding the borders
        while (center_frame_idx + half_N_frames * interval >
               max_frame) or (center_frame_idx - half_N_frames * interval < 0):
            center_frame_idx = random.randint(0, max_frame)
        # get the neighbor list
        neighbor_list = list(
            range(center_frame_idx - half_N_frames * interval,
                  center_frame_idx + half_N_frames * interval + 1, interval))
        if random_reverse and random.random() < 0.5:
            neighbor_list.reverse()
        name_b = '{:08d}'.format(neighbor_list[half_N_frames])
    assert len(neighbor_list) == number_frames, \
        "frames selected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)
    return neighbor_list, name_b


def read_img(path, size=None, is_gt=False):
    """read image by cv2
    return: Numpy float32, HWC, BGR, [0,1]"""
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    #if not is_gt:
    #    #print(path)
    #    img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


def img_augment(img_list, hflip=True, rot=True):
    """horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]
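
`img_augment` draws its three random decisions once and applies them to every frame and to the ground truth, so the whole clip stays spatially aligned. Combining a horizontal flip, a vertical flip and a transpose reaches all eight flip/rotation variants of an image. The sketch below checks this on a tiny nested-list grid instead of a NumPy array; `augment` is a hypothetical stand-in for the inner `_augment` function.

```python
from itertools import product

def augment(grid, hflip, vflip, transpose):
    """Apply the same three optional operations as `_augment`, on nested lists."""
    if hflip:
        grid = [row[::-1] for row in grid]   # mirror left-right
    if vflip:
        grid = grid[::-1]                    # mirror top-bottom
    if transpose:
        grid = [list(col) for col in zip(*grid)]  # swap the H and W axes
    return grid

base = [[1, 2],
        [3, 4]]
variants = {tuple(map(tuple, augment(base, h, v, t)))
            for h, v, t in product([False, True], repeat=3)}
print(len(variants))  # 8
```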


def make_reader(filelist,
                num_threads,
                batch_size,
                is_training,
                number_frames,
                interval_list,
                random_reverse,
                gtroot,
                fileroot,
                LR_input,
                crop_size,
                scale,
                use_flip,
                use_rot,
                mode='train'):
    fl = filelist

    def reader_():
        if is_training:
            random.shuffle(fl)
        batch_out = []
        for item in fl:
            img_LQs, img_GT = get_sample_data(item,
                                              number_frames, interval_list, random_reverse, gtroot, fileroot,
                                              LR_input, crop_size, scale, use_flip, use_rot, mode)
            videoname = item.split('_')[0]
            framename = item.split('_')[1]
            if (mode == 'train') or (mode == 'valid'):
                batch_out.append((img_LQs, img_GT))
            elif mode == 'test':
                batch_out.append((img_LQs, img_GT, videoname, framename))
            else:
                raise NotImplementedError("mode {} not implemented".format(mode))
            if len(batch_out) == batch_size:
                yield batch_out
                batch_out = []

    return reader_


def make_multi_reader(filelist,
                      num_threads,
                      batch_size,
                      is_training,
                      number_frames,
                      interval_list,
                      random_reverse,
                      gtroot,
                      fileroot,
                      LR_input,
                      crop_size,
                      scale,
                      use_flip,
                      use_rot,
                      mode='train'):
    def read_into_queue(flq, queue):
        batch_out = []
        for item in flq:
            img_LQs, img_GT = get_sample_data(item,
                                              number_frames, interval_list, random_reverse, gtroot, fileroot,
                                              LR_input, crop_size, scale, use_flip, use_rot, mode)
            videoname = item.split('_')[0]
            framename = item.split('_')[1]
            if (mode == 'train') or (mode == 'valid'):
                batch_out.append((img_LQs, img_GT))
            elif mode == 'test':
                batch_out.append((img_LQs, img_GT, videoname, framename))
            else:
                raise NotImplementedError("mode {} not implemented".format(mode))
            if len(batch_out) == batch_size:
                queue.put(batch_out)
                batch_out = []
        queue.put(None)

    def queue_reader():
        fl = filelist
        if is_training:
            random.shuffle(fl)
        n = num_threads
        queue_size = 20
        reader_lists = [None] * n
        file_num = int(len(fl) // n)
        for i in range(n):
            if i < len(reader_lists) - 1:
                tmp_list = fl[i * file_num:(i + 1) * file_num]
            else:
                tmp_list = fl[i * file_num:]
            reader_lists[i] = tmp_list
        queue = multiprocessing.Queue(queue_size)
        p_list = [None] * len(reader_lists)
        # for reader_list in reader_lists:
        for i in range(len(reader_lists)):
            reader_list = reader_lists[i]
            p_list[i] = multiprocessing.Process(
                target=read_into_queue, args=(reader_list, queue))
            p_list[i].start()
        reader_num = len(reader_lists)
        finish_num = 0
        while finish_num < reader_num:
            sample = queue.get()
            if sample is None:
                finish_num += 1
            else:
                yield sample
        for i in range(len(p_list)):
            if p_list[i].is_alive():
                p_list[i].join()

    return queue_reader
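
`make_multi_reader` shards the file list across workers, lets each worker push full batches into a bounded queue, and uses one `None` per worker as an end-of-data sentinel. The sketch below reproduces that control flow with plain integers as samples. It deliberately swaps `multiprocessing.Process` for `threading.Thread` so that it runs anywhere without a `__main__` guard; the names `shard`, `worker` and `batches` are hypothetical.

```python
import queue
import threading

def shard(items, n):
    """Split items into n contiguous shards; the last shard takes the remainder."""
    size = len(items) // n
    return [items[i * size:(i + 1) * size] if i < n - 1 else items[i * size:]
            for i in range(n)]

def worker(items, batch_size, out_q):
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            out_q.put(batch)
            batch = []
    # Sentinel: tells the consumer this worker is done (a partial batch is dropped).
    out_q.put(None)

def batches(items, num_workers, batch_size):
    out_q = queue.Queue(maxsize=20)
    threads = [threading.Thread(target=worker, args=(s, batch_size, out_q))
               for s in shard(items, num_workers)]
    for t in threads:
        t.start()
    finished = 0
    while finished < len(threads):
        batch = out_q.get()
        if batch is None:
            finished += 1
        else:
            yield batch
    for t in threads:
        t.join()

print(sorted(batches(list(range(10)), num_workers=3, batch_size=2)))
```

As in the original reader, each worker silently drops a trailing partial batch, so with several workers up to `batch_size - 1` samples per shard can go unused in an epoch.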
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import cv2
import math
import random
import multiprocessing
import functools
import numpy as np
import paddle
import cv2
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
python_ver = sys.version_info
random.seed(0)
np.random.seed(0)
class EDVRReader(DataReader):
"""
Data reader for video super resolution task fit for EDVR model.
This is specified for REDS dataset.
"""
def __init__(self, name, mode, cfg):
super(EDVRReader, self).__init__(name, mode, cfg)
self.format = cfg.MODEL.format
self.crop_size = self.get_config_from_sec(mode, 'crop_size')
self.interval_list = self.get_config_from_sec(mode, 'interval_list')
self.random_reverse = self.get_config_from_sec(mode, 'random_reverse')
self.number_frames = self.get_config_from_sec(mode, 'number_frames')
# set batch size and file list
self.batch_size = cfg[mode.upper()]['batch_size']
self.fileroot = cfg[mode.upper()]['file_root']
self.use_flip = self.get_config_from_sec(mode, 'use_flip', False)
self.use_rot = self.get_config_from_sec(mode, 'use_rot', False)
self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1)
self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024)
self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False)
if self.mode != 'infer':
self.gtroot = self.get_config_from_sec(mode, 'gt_root')
self.scale = self.get_config_from_sec(mode, 'scale', 1)
self.LR_input = (self.scale > 1)
if self.fix_random_seed:
random.seed(0)
np.random.seed(0)
self.num_reader_threads = 1
"""
def create_reader(self):
logger.info('initialize reader ... ')
self.filelist = []
for video_name in os.listdir(self.fileroot):
if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
continue
for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
frame_idx = frame_name.split('.')[0]
video_frame_idx = video_name + '_' + frame_idx
# for each item in self.filelist is like '010_00000015', '260_00000090'
self.filelist.append(video_frame_idx)
#self.filelist.sort()
def reader_():
### not implemented border mode, maybe add later ############
if self.mode == 'train':
random.shuffle(self.filelist)
for item in self.filelist:
#print(item)
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
ngb_frames, name_b = get_neighbor_frames(frame_name, \
number_frames = self.number_frames, \
interval_list = self.interval_list, \
random_reverse = self.random_reverse)
frame_name = name_b
#print('key2', ngb_frames, name_b)
img_GT = read_img(os.path.join(self.gtroot, video_name, frame_name + '.png'))
#print('gt_mean', np.mean(img_GT))
frame_list = []
for ngb_frm in ngb_frames:
ngb_name = "%08d"%ngb_frm
#img = read_img(os.path.join(self.fileroot, video_name, frame_name + '.png'))
img = read_img(os.path.join(self.fileroot, video_name, ngb_name + '.png'))
frame_list.append(img)
#print('img_mean', np.mean(img))
H, W, C = frame_list[0].shape
# add random crop
if self.LR_input:
LQ_size = self.crop_size // self.scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
#print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
rnd_h_HR, rnd_w_HR = int(rnd_h * self.scale), int(rnd_w * self.scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + self.crop_size, rnd_w_HR:rnd_w_HR + self.crop_size, :]
else:
rnd_h = random.randint(0, max(0, H - self.crop_size))
rnd_w = random.randint(0, max(0, W - self.crop_size))
frame_list = [v[rnd_h:rnd_h + self.crop_size, rnd_w:rnd_w + self.crop_size, :] for v in frame_list]
img_GT = img_GT[rnd_h:rnd_h + self.crop_size, rnd_w:rnd_w + self.crop_size, :]
# add random flip and rotation
frame_list.append(img_GT)
rlt = img_augment(frame_list, self.use_flip, self.use_rot)
frame_list = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(frame_list, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
yield img_LQs, img_GT
def _batch_reader():
batch_out = []
for img_LQs, img_GT in reader_():
#print('lq', img_LQs.shape)
#print('gt', img_GT.shape)
batch_out.append((img_LQs, img_GT))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return _batch_reader
"""
def create_reader(self):
logger.info('initialize reader ... ')
self.filelist = []
for video_name in os.listdir(self.fileroot):
if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
continue
for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
frame_idx = frame_name.split('.')[0]
video_frame_idx = video_name + '_' + frame_idx
# for each item in self.filelist is like '010_00000015', '260_00000090'
self.filelist.append(video_frame_idx)
if self.num_reader_threads == 1:
reader_func = make_reader
else:
reader_func = make_multi_reader
return reader_func(filelist = self.filelist,
num_threads = self.num_reader_threads,
batch_size = self.batch_size,
is_training = (self.mode == 'train'),
number_frames = self.number_frames,
interval_list = self.interval_list,
random_reverse = self.random_reverse,
gtroot = self.gtroot,
fileroot = self.fileroot,
LR_input = self.LR_input,
crop_size = self.crop_size,
scale = self.scale,
use_flip = self.use_flip,
use_rot = self.use_rot)
def get_sample_data(item, number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot):
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
ngb_frames, name_b = get_neighbor_frames(frame_name, \
number_frames = number_frames, \
interval_list = interval_list, \
random_reverse = random_reverse)
frame_name = name_b
#print('key2', ngb_frames, name_b)
img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'))
#print('gt_mean', np.mean(img_GT))
frame_list = []
for ngb_frm in ngb_frames:
ngb_name = "%08d"%ngb_frm
#img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png'))
frame_list.append(img)
#print('img_mean', np.mean(img))
H, W, C = frame_list[0].shape
# add random crop
if LR_input:
LQ_size = crop_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
#print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :]
else:
rnd_h = random.randint(0, max(0, H - crop_size))
rnd_w = random.randint(0, max(0, W - crop_size))
frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list]
img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
# add random flip and rotation
frame_list.append(img_GT)
rlt = img_augment(frame_list, use_flip, use_rot)
frame_list = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(frame_list, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
return img_LQs, img_GT
def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, bordermode=False):
center_frame_idx = int(frame_name)
half_N_frames = number_frames // 2
#### determine the neighbor frames
interval = random.choice(interval_list)
if bordermode:
direction = 1 # 1: forward; 0: backward
if random_reverse and random.random() < 0.5:
direction = random.choice([0, 1])
if center_frame_idx + interval * (number_frames - 1) > 99:
direction = 0
elif center_frame_idx - interval * (number_frames - 1) < 0:
direction = 1
# get the neighbor list
if direction == 1:
neighbor_list = list(
range(center_frame_idx, center_frame_idx + interval * number_frames, interval))
else:
neighbor_list = list(
range(center_frame_idx, center_frame_idx - interval * number_frames, -interval))
name_b = '{:08d}'.format(neighbor_list[0])
else:
# ensure not exceeding the borders
while (center_frame_idx + half_N_frames * interval >
99) or (center_frame_idx - half_N_frames * interval < 0):
center_frame_idx = random.randint(0, 99)
# get the neighbor list
neighbor_list = list(
range(center_frame_idx - half_N_frames * interval,
center_frame_idx + half_N_frames * interval + 1, interval))
if random_reverse and random.random() < 0.5:
neighbor_list.reverse()
name_b = '{:08d}'.format(neighbor_list[half_N_frames])
assert len(neighbor_list) == number_frames, \
"frames selected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)
return neighbor_list, name_b
def read_img(path, size=None):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
def img_augment(img_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
def make_reader(filelist,
num_threads,
batch_size,
is_training,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
LR_input,
crop_size,
scale,
use_flip,
use_rot):
fl = filelist
def reader_():
if is_training:
random.shuffle(fl)
batch_out = []
for item in fl:
img_LQs, img_GT = get_sample_data(item,
number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot)
batch_out.append((img_LQs, img_GT))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
return reader_
def make_multi_reader(filelist,
num_threads,
batch_size,
is_training,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
LR_input,
crop_size,
scale,
use_flip,
use_rot):
def read_into_queue(flq, queue):
batch_out = []
for item in flq:
img_LQs, img_GT = get_sample_data(item,
number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot)
batch_out.append((img_LQs, img_GT))
if len(batch_out) == batch_size:
queue.put(batch_out)
batch_out = []
queue.put(None)
def queue_reader():
fl = filelist
if is_training:
random.shuffle(fl)
n = num_threads
queue_size = 20
reader_lists = [None] * n
file_num = int(len(fl) // n)
for i in range(n):
if i < len(reader_lists) - 1:
tmp_list = fl[i * file_num:(i + 1) * file_num]
else:
tmp_list = fl[i * file_num:]
reader_lists[i] = tmp_list
queue = multiprocessing.Queue(queue_size)
p_list = [None] * len(reader_lists)
# for reader_list in reader_lists:
for i in range(len(reader_lists)):
reader_list = reader_lists[i]
p_list[i] = multiprocessing.Process(
target=read_into_queue, args=(reader_list, queue))
p_list[i].start()
reader_num = len(reader_lists)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for i in range(len(p_list)):
if p_list[i].is_alive():
p_list[i].join()
return queue_reader
......@@ -11,73 +11,24 @@ mode=$1
name=$2
configs=$3
pretrain="./tmp/name_map/paddle_state_dict.npz" # set pretrain model path if needed
resume="" # set pretrain model path if needed
save_dir="./data/checkpoints"
#pretrain="./tmp/name_map/paddle_state_dict.npz" # set pretrain model path if needed
#resume="" # set pretrain model path if needed
#save_dir="./data/checkpoints"
save_inference_dir="./data/inference_model"
use_gpu=True
fix_random_seed=False
log_interval=1
valid_interval=1
#weights="./data/checkpoints/EDVR_epoch721.pdparams" #set the path of weights to enable eval and predict, just ignore this when training
#weights="./data/checkpoints_with_tsa/EDVR_epoch821.pdparams"
weights="./weights/paddle_state_dict_L.npz"
#weights="./weights/"
#export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES=4,5,6,7 #0,1,5,6 fast, 2,3,4,7 slow
#export CUDA_VISIBLE_DEVICES=7
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
if [ "$mode" == "train" ]; then
echo $mode $name $configs $resume $pretrain
if [ "$resume"x != ""x ]; then
python train.py --model_name=$name \
--config=$configs \
--resume=$resume \
--log_interval=$log_interval \
--valid_interval=$valid_interval \
--use_gpu=$use_gpu \
--save_dir=$save_dir \
--fix_random_seed=$fix_random_seed
elif [ "$pretrain"x != ""x ]; then
python train.py --model_name=$name \
--config=$configs \
--pretrain=$pretrain \
--log_interval=$log_interval \
--valid_interval=$valid_interval \
--use_gpu=$use_gpu \
--save_dir=$save_dir \
--fix_random_seed=$fix_random_seed
else
python train.py --model_name=$name \
--config=$configs \
--log_interval=$log_interval \
--valid_interval=$valid_interval \
--use_gpu=$use_gpu \
--save_dir=$save_dir \
--fix_random_seed=$fix_random_seed
fi
elif [ "$mode"x == "eval"x ]; then
echo $mode $name $configs $weights
if [ "$weights"x != ""x ]; then
python eval.py --model_name=$name \
--config=$configs \
--log_interval=$log_interval \
--weights=$weights \
--use_gpu=$use_gpu
else
python eval.py --model_name=$name \
--config=$configs \
--log_interval=$log_interval \
--use_gpu=$use_gpu
fi
elif [ "$mode"x == "predict"x ]; then
if [ "$mode"x == "predict"x ]; then
echo $mode $name $configs $weights
if [ "$weights"x != ""x ]; then
python predict.py --model_name=$name \
......