PaddlePaddle
/
PaddleGAN
Commit 8c46c443
Authored Aug 12, 2020 by lijianshe02
delete inference-unrelated code
Parent: 39937404
Showing 6 changed files with 4 additions and 1022 deletions (+4 -1022)
applications/EDVR/configs/edvr_L.yaml (+0 -37)
applications/EDVR/models/__init__.py (+0 -27)
applications/EDVR/models/edvr/README.md (+0 -112)
applications/EDVR/reader/edvr_reader.py.bak (+0 -395)
applications/EDVR/reader/edvr_reader.pybk (+0 -398)
applications/EDVR/run.sh (+4 -53)
applications/EDVR/configs/edvr_L.yaml
@@ -11,43 +11,6 @@ MODEL:
     HR_in: False
     w_TSA: True #False
-TRAIN:
-    epoch: 45
-    use_gpu: True
-    num_gpus: 4 #8
-    scale: 4
-    crop_size: 256
-    interval_list: [1]
-    random_reverse: False
-    number_frames: 5
-    batch_size: 32
-    file_root: "/workspace/video_test/video/data/dataset/edvr/REDS/train_sharp_bicubic/X4"
-    gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS/train_sharp"
-    use_flip: True
-    use_rot: True
-    base_learning_rate: 0.0004
-    l2_weight_decay: 0.0
-    TSA_only: False
-    T_periods: [50000, 100000, 150000, 150000, 150000] # for cosine annealing restart
-    restarts: [50000, 150000, 300000, 450000] # for cosine annealing restart
-    weights: [1, 1, 1, 1] # for cosine annealing restart
-    eta_min: 1e-7 # for cosine annealing restart
-    num_reader_threads: 8
-    buf_size: 1024
-    fix_random_seed: False
-VALID:
-    scale: 4
-    crop_size: 256
-    interval_list: [1]
-    random_reverse: False
-    number_frames: 5
-    batch_size: 32 #256
-    file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic"
-    gt_root: "/workspace/video_test/video/data/dataset/edvr/edvr/REDS4/GT"
-    use_flip: False
-    use_rot: False
 TEST:
     scale: 4
     crop_size: 256
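The removed TRAIN keys commented "for cosine annealing restart" point to a cosine learning-rate schedule with warm restarts. The following is only a rough Python sketch of how such a schedule could be evaluated. It assumes, without confirmation from this diff, that `T_periods` is the length of each cosine cycle, `restarts` is the iteration at which each new cycle begins, and `weights` scales the base rate after each restart. `cosine_restart_lr` is a hypothetical helper, not a function from the repository.

```python
import math


def cosine_restart_lr(step, base_lr, t_periods, restarts, weights, eta_min):
    """Learning rate at `step` for cosine annealing with warm restarts (illustrative)."""
    # Cycle 0 starts at step 0; cycle k starts at restarts[k - 1].
    cycle = sum(1 for r in restarts if step >= r)
    start = 0 if cycle == 0 else restarts[cycle - 1]
    # Assumed meaning of `weights`: multiplier applied to the peak rate after a restart.
    peak = base_lr if cycle == 0 else base_lr * weights[cycle - 1]
    progress = min((step - start) / t_periods[cycle], 1.0)
    return eta_min + 0.5 * (peak - eta_min) * (1.0 + math.cos(math.pi * progress))


# Values copied from the removed TRAIN section.
T_PERIODS = [50000, 100000, 150000, 150000, 150000]
RESTARTS = [50000, 150000, 300000, 450000]
WEIGHTS = [1, 1, 1, 1]
BASE_LR = 0.0004
ETA_MIN = 1e-7

for s in (0, 25000, 49999, 50000):
    print(s, cosine_restart_lr(s, BASE_LR, T_PERIODS, RESTARTS, WEIGHTS, ETA_MIN))
```

Under these assumptions the cumulative sums of `T_periods` coincide with the `restarts` values, so each cycle decays to roughly `eta_min` just before the next one begins.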
applications/EDVR/models/__init__.py
from .model import regist_model, get_model
#from .attention_cluster import AttentionCluster
#from .attention_lstm import AttentionLSTM
#from .nextvlad import NEXTVLAD
#from .nonlocal_model import NonLocal
#from .tsm import TSM
#from .tsn import TSN
#from .stnet import STNET
#from .ctcn import CTCN
#from .bmn import BMN
#from .bsn import BsnTem
#from .bsn import BsnPem
#from .ets import ETS
#from .tall import TALL
from .edvr import EDVR
# regist models, sort by alphabet
#regist_model("AttentionCluster", AttentionCluster)
#regist_model("AttentionLSTM", AttentionLSTM)
#regist_model("NEXTVLAD", NEXTVLAD)
#regist_model('NONLOCAL', NonLocal)
#regist_model("TSM", TSM)
#regist_model("TSN", TSN)
#regist_model("STNET", STNET)
#regist_model("CTCN", CTCN)
#regist_model("BMN", BMN)
#regist_model("BsnTem", BsnTem)
#regist_model("BsnPem", BsnPem)
#regist_model("ETS", ETS)
#regist_model("TALL", TALL)
regist_model("EDVR", EDVR)
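The file above relies on a name-to-class registry through `regist_model` and `get_model`, whose implementation in `.model` is not part of this diff. As a hedged illustration of the pattern only, a registry of this kind might look like the sketch below. The `_MODELS` dictionary, the function bodies, and the placeholder `EDVR` class are assumptions made for the example.

```python
_MODELS = {}  # hypothetical module-level registry: model name -> model class


def regist_model(name, model_cls):
    """Register a model class under a unique name."""
    if name in _MODELS:
        raise ValueError("model {} is already registered".format(name))
    _MODELS[name] = model_cls


def get_model(name, *args, **kwargs):
    """Instantiate a registered model by name."""
    if name not in _MODELS:
        raise KeyError("model {} is not registered".format(name))
    return _MODELS[name](*args, **kwargs)


class EDVR(object):
    """Placeholder standing in for the real EDVR model class."""

    def __init__(self, cfg=None):
        self.cfg = cfg


regist_model("EDVR", EDVR)
model = get_model("EDVR", cfg={"scale": 4})
print(type(model).__name__)
```

With a registry like this, a launcher script can select the model from a `--model_name` string without importing model classes directly.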
applications/EDVR/models/edvr/README.md
Deleted (100644 → 0), shown as of parent 39937404
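# TSN Video Classification Model
---
## Contents
- [Model Overview](#model-overview)
- [Data Preparation](#data-preparation)
- [Model Training](#model-training)
- [Model Evaluation](#model-evaluation)
- [Model Inference](#model-inference)
- [References](#references)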
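## Model Overview
Temporal Segment Network (TSN) is a classic 2D-CNN-based approach to video classification. It targets long-range action recognition in video: instead of sampling frames densely, it samples them sparsely, which captures global information about the video while removing redundancy and reducing computation. The per-frame features are then averaged into a video-level feature that is used for classification. The model implemented here is a single-stream TSN on RGB images with a ResNet-50 backbone.

For details, see the ECCV 2016 paper [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859)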
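The sparse sampling and average fusion described above can be sketched in a few lines of Python. This is an illustrative sketch under stated assumptions, not the repository's reader. `sample_segment_indices` and `average_consensus` are hypothetical names, and the deterministic centre-frame choice is one possible convention for evaluation.

```python
import random

import numpy as np


def sample_segment_indices(num_frames, seg_num, rng=None):
    """Pick one frame index from each of `seg_num` equal-length segments.

    With rng=None the centre frame of each segment is used (deterministic);
    otherwise a frame is drawn at random inside each segment.
    """
    bounds = [num_frames * k // seg_num for k in range(seg_num + 1)]
    indices = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        end = max(end, start + 1)  # guard against empty segments in very short videos
        if rng is None:
            indices.append((start + end - 1) // 2)
        else:
            indices.append(rng.randrange(start, end))
    return indices


def average_consensus(frame_scores):
    """Fuse per-frame scores of shape (seg_num, num_classes) into one video-level score."""
    return np.mean(np.asarray(frame_scores), axis=0)


centre = sample_segment_indices(300, 3)
jittered = sample_segment_indices(300, 3, rng=random.Random(0))
video_score = average_consensus([[0.2, 0.8], [0.4, 0.6]])
print(centre, jittered, video_score)
```

Only `seg_num` frames pass through the backbone per video, which is where the reduction in computation comes from.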
## Data Preparation
TSN is trained on the Kinetics-400 action recognition dataset released by DeepMind. For data download and preparation, see the [data instructions](../../data/dataset/README.md)
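## Model Training
Once the data is ready, training can be started in either of the following two ways:

    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
    export FLAGS_fast_eager_deletion_mode=1
    export FLAGS_eager_delete_tensor_gb=0.0
    export FLAGS_fraction_of_gpu_memory_to_use=0.98
    python train.py --model_name=TSN \
                    --config=./configs/tsn.yaml \
                    --log_interval=10 \
                    --valid_interval=1 \
                    --use_gpu=True \
                    --save_dir=./data/checkpoints \
                    --fix_random_seed=False \
                    --pretrain=$PATH_TO_PRETRAIN_MODEL

    bash run.sh train TSN ./configs/tsn.yaml

- To train from scratch, the ResNet50 weights trained on ImageNet must be loaded as the initial parameters. Download these [model parameters](https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz) and extract them, then set the `pretrain` argument in the command above or in the run.sh script to the path of the extracted parameters. If the parameters are not downloaded manually and `pretrain` is not set, the program downloads them automatically and saves them under ~/.paddle/weights/ResNet50\_pretrained.
- The released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) can also be downloaded and its path passed via `--resume` for fine-tuning and further development.

**Data reader:** The model reads `mp4` data from the Kinetics-400 dataset. Each video is split into `seg_num` segments and one frame is sampled from each segment; each frame is randomly augmented and then resized to `target_size`.

**Training strategy:**
* Training uses the Momentum optimizer with momentum=0.9
* The weight decay coefficient is 1e-4
* The learning rate is decayed by a factor of 0.1 at 1/3 and at 2/3 of the total number of epochs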
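The step schedule in the training strategy above can be written out as a small Python sketch. It is a sketch under assumptions: the base learning rate (0.01) and total epoch count (45) are hypothetical values chosen for illustration, not values stated in this README, and `step_decay_lr` is a made-up helper name.

```python
def step_decay_lr(epoch, total_epochs, base_lr, gamma=0.1):
    """Learning rate for `epoch`, decayed by `gamma` at 1/3 and 2/3 of training."""
    boundaries = [total_epochs // 3, 2 * total_epochs // 3]
    num_decays = sum(1 for b in boundaries if epoch >= b)
    return base_lr * gamma ** num_decays


# Hypothetical settings, for illustration only.
TOTAL_EPOCHS = 45
BASE_LR = 0.01

for epoch in (0, 14, 15, 29, 30, 44):
    print(epoch, step_decay_lr(epoch, TOTAL_EPOCHS, BASE_LR))
```

With these numbers the rate drops at epochs 15 and 30, so the final third of training runs at one hundredth of the base rate.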
## Model Evaluation
The model can be evaluated in either of the following two ways:

    python eval.py --model_name=TSN \
                   --config=./configs/tsn.yaml \
                   --log_interval=1 \
                   --weights=$PATH_TO_WEIGHTS \
                   --use_gpu=True

    bash run.sh eval TSN ./configs/tsn.yaml

- When evaluating with `run.sh`, set the `weights` parameter in the script to the weights to be evaluated
- If `--weights` is not specified, the script downloads the released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) and evaluates it
- Evaluation results are printed directly to the log and include accuracy metrics such as TOP1\_ACC and TOP5\_ACC
- To evaluate on CPU, set `use_gpu` to False in the command above or in the run.sh script

With the following parameters, the accuracy on the Kinetics400 validation set is:

| seg\_num | target\_size | Top-1 |
| :------: | :----------: | :----: |
| 3 | 224 | 0.66 |
| 7 | 224 | 0.67 |
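## Model Inference
Inference can be started in either of the following two ways:

    python predict.py --model_name=TSN \
                      --config=./configs/tsn.yaml \
                      --log_interval=1 \
                      --weights=$PATH_TO_WEIGHTS \
                      --filelist=$FILELIST \
                      --use_gpu=True \
                      --video_path=$VIDEO_PATH

    bash run.sh predict TSN ./configs/tsn.yaml

- When running inference with `run.sh`, set the `weights` parameter in the script to the weights to be used.
- If video\_path is '', the parameter is ignored. If video\_path != '', the program runs prediction on the video file given by video\_path and ignores the value of filelist; the result is the classification probabilities for that video.
- If `--weights` is not specified, the script downloads the released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) and uses it for inference
- Inference results are printed directly to the log, showing the predicted class probabilities for each test sample.
- To run inference on CPU, set `use_gpu` to False in the command or in the run.sh script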
## References
- [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859), Limin Wang, Yuanjun Xiong, Zhe Wang, Yu Qiao, Dahua Lin, Xiaoou Tang, Luc Van Gool
applications/EDVR/reader/edvr_reader.py.bak
Deleted (100644 → 0), shown as of parent 39937404
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import cv2
import math
import random
import multiprocessing
import functools
import numpy as np
import paddle
import cv2
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
python_ver = sys.version_info
random.seed(0)
np.random.seed(0)
class EDVRReader(DataReader):
    """
    Data reader for video super resolution task fit for EDVR model.
    This is specified for REDS dataset.
    """
    def __init__(self, name, mode, cfg):
        super(EDVRReader, self).__init__(name, mode, cfg)
        self.format = cfg.MODEL.format
        self.crop_size = self.get_config_from_sec(mode, 'crop_size')
        self.interval_list = self.get_config_from_sec(mode, 'interval_list')
        self.random_reverse = self.get_config_from_sec(mode, 'random_reverse')
        self.number_frames = self.get_config_from_sec(mode, 'number_frames')
        # set batch size and file list
        self.batch_size = cfg[mode.upper()]['batch_size']
        self.fileroot = cfg[mode.upper()]['file_root']
        self.use_flip = self.get_config_from_sec(mode, 'use_flip', False)
        self.use_rot = self.get_config_from_sec(mode, 'use_rot', False)
        self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1)
        self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024)
        self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False)
        if self.mode != 'infer':
            self.gtroot = self.get_config_from_sec(mode, 'gt_root')
            self.scale = self.get_config_from_sec(mode, 'scale', 1)
            self.LR_input = (self.scale > 1)
        if self.fix_random_seed:
            random.seed(0)
            np.random.seed(0)
            self.num_reader_threads = 1
    def create_reader(self):
        logger.info('initialize reader ... ')
        self.filelist = []
        for video_name in os.listdir(self.fileroot):
            if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
                continue
            for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
                frame_idx = frame_name.split('.')[0]
                video_frame_idx = video_name + '_' + frame_idx
                # for each item in self.filelist is like '010_00000015', '260_00000090'
                self.filelist.append(video_frame_idx)
        if self.mode == 'test':
            self.filelist.sort()
        if self.num_reader_threads == 1:
            reader_func = make_reader
        else:
            reader_func = make_multi_reader
        return reader_func(filelist = self.filelist,
                           num_threads = self.num_reader_threads,
                           batch_size = self.batch_size,
                           is_training = (self.mode == 'train'),
                           number_frames = self.number_frames,
                           interval_list = self.interval_list,
                           random_reverse = self.random_reverse,
                           gtroot = self.gtroot,
                           fileroot = self.fileroot,
                           LR_input = self.LR_input,
                           crop_size = self.crop_size,
                           scale = self.scale,
                           use_flip = self.use_flip,
                           use_rot = self.use_rot,
                           mode = self.mode)
def get_sample_data(item, number_frames, interval_list, random_reverse, gtroot, fileroot,
                    LR_input, crop_size, scale, use_flip, use_rot, mode='train'):
    video_name = item.split('_')[0]
    frame_name = item.split('_')[1]
    if (mode == 'train') or (mode == 'valid'):
        ngb_frames, name_b = get_neighbor_frames(frame_name, \
                                                 number_frames = number_frames, \
                                                 interval_list = interval_list, \
                                                 random_reverse = random_reverse)
    elif mode == 'test':
        ngb_frames, name_b = get_test_neighbor_frames(int(frame_name), number_frames)
    else:
        raise NotImplementedError('mode {} not implemented'.format(mode))
    frame_name = name_b
    print('key2', ngb_frames, name_b)
    img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'), is_gt=True)
    #print('gt_mean', np.mean(img_GT))
    frame_list = []
    for ngb_frm in ngb_frames:
        ngb_name = "%08d"%ngb_frm
        #img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
        img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png'))
        frame_list.append(img)
        #print('img_mean', np.mean(img))
    H, W, C = frame_list[0].shape
    # add random crop
    if (mode == 'train') or (mode == 'valid'):
        if LR_input:
            LQ_size = crop_size // scale
            rnd_h = random.randint(0, max(0, H - LQ_size))
            rnd_w = random.randint(0, max(0, W - LQ_size))
            #print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
            frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :]
        else:
            rnd_h = random.randint(0, max(0, H - crop_size))
            rnd_w = random.randint(0, max(0, W - crop_size))
            frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list]
            img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
    # add random flip and rotation
    frame_list.append(img_GT)
    if (mode == 'train') or (mode == 'valid'):
        rlt = img_augment(frame_list, use_flip, use_rot)
    else:
        rlt = frame_list
    frame_list = rlt[0:-1]
    img_GT = rlt[-1]
    # stack LQ images to NHWC, N is the frame number
    img_LQs = np.stack(frame_list, axis=0)
    # BGR to RGB, HWC to CHW, numpy to tensor
    img_GT = img_GT[:, :, [2, 1, 0]]
    img_LQs = img_LQs[:, :, :, [2, 1, 0]]
    img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
    img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
    return img_LQs, img_GT
def get_test_neighbor_frames(crt_i, N, max_n=100, padding='new_info'):
    """Generate an index list for reading N frames from a sequence of images
    Args:
        crt_i (int): current center index
        max_n (int): max number of the sequence of images (calculated from 1)
        N (int): reading N frames
        padding (str): padding mode, one of replicate | reflection | new_info | circle
            Example: crt_i = 0, N = 5
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            new_info: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]
    Returns:
        return_l (list [int]): a list of indexes
    """
    max_n = max_n - 1
    n_pad = N // 2
    return_l = []
    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            if padding == 'replicate':
                add_idx = 0
            elif padding == 'reflection':
                add_idx = -i
            elif padding == 'new_info':
                add_idx = (crt_i + n_pad) + (-i)
            elif padding == 'circle':
                add_idx = N + i
            else:
                raise ValueError('Wrong padding mode')
        elif i > max_n:
            if padding == 'replicate':
                add_idx = max_n
            elif padding == 'reflection':
                add_idx = max_n * 2 - i
            elif padding == 'new_info':
                add_idx = (crt_i - n_pad) - (i - max_n)
            elif padding == 'circle':
                add_idx = i - N
            else:
                raise ValueError('Wrong padding mode')
        else:
            add_idx = i
        return_l.append(add_idx)
    name_b = '{:08d}'.format(crt_i)
    return return_l, name_b
def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, max_frame=99, bordermode=False):
    center_frame_idx = int(frame_name)
    half_N_frames = number_frames // 2
    #### determine the neighbor frames
    interval = random.choice(interval_list)
    if bordermode:
        direction = 1 # 1: forward; 0: backward
        if random_reverse and random.random() < 0.5:
            direction = random.choice([0, 1])
        if center_frame_idx + interval * (number_frames - 1) > max_frame:
            direction = 0
        elif center_frame_idx - interval * (number_frames - 1) < 0:
            direction = 1
        # get the neighbor list
        if direction == 1:
            neighbor_list = list(
                range(center_frame_idx, center_frame_idx + interval * number_frames, interval))
        else:
            neighbor_list = list(
                range(center_frame_idx, center_frame_idx - interval * number_frames, -interval))
        name_b = '{:08d}'.format(neighbor_list[0])
    else:
        # ensure not exceeding the borders
        while (center_frame_idx + half_N_frames * interval >
               max_frame) or (center_frame_idx - half_N_frames * interval < 0):
            center_frame_idx = random.randint(0, max_frame)
        # get the neighbor list
        neighbor_list = list(
            range(center_frame_idx - half_N_frames * interval,
                  center_frame_idx + half_N_frames * interval + 1, interval))
        if random_reverse and random.random() < 0.5:
            neighbor_list.reverse()
        name_b = '{:08d}'.format(neighbor_list[half_N_frames])
    assert len(neighbor_list) == number_frames, \
        "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)
    return neighbor_list, name_b
def read_img(path, size=None, is_gt=False):
    """read image by cv2
    return: Numpy float32, HWC, BGR, [0,1]"""
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    #if not is_gt:
    #    #print(path)
    #    img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
def img_augment(img_list, hflip=True, rot=True):
    """horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5
    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img
    return [_augment(img) for img in img_list]
def make_reader(filelist,
                num_threads,
                batch_size,
                is_training,
                number_frames,
                interval_list,
                random_reverse,
                gtroot,
                fileroot,
                LR_input,
                crop_size,
                scale,
                use_flip,
                use_rot,
                mode='train'):
    fl = filelist
    def reader_():
        if is_training:
            random.shuffle(fl)
        batch_out = []
        for item in fl:
            img_LQs, img_GT = get_sample_data(item,
                number_frames, interval_list, random_reverse, gtroot, fileroot,
                LR_input, crop_size, scale, use_flip, use_rot, mode)
            videoname = item.split('_')[0]
            framename = item.split('_')[1]
            if (mode == 'train') or (mode == 'valid'):
                batch_out.append((img_LQs, img_GT))
            elif mode == 'test':
                batch_out.append((img_LQs, img_GT, videoname, framename))
            else:
                raise NotImplementedError("mode {} not implemented".format(mode))
            if len(batch_out) == batch_size:
                yield batch_out
                batch_out = []
    return reader_
def make_multi_reader(filelist,
                      num_threads,
                      batch_size,
                      is_training,
                      number_frames,
                      interval_list,
                      random_reverse,
                      gtroot,
                      fileroot,
                      LR_input,
                      crop_size,
                      scale,
                      use_flip,
                      use_rot,
                      mode='train'):
    def read_into_queue(flq, queue):
        batch_out = []
        for item in flq:
            img_LQs, img_GT = get_sample_data(item,
                number_frames, interval_list, random_reverse, gtroot, fileroot,
                LR_input, crop_size, scale, use_flip, use_rot, mode)
            videoname = item.split('_')[0]
            framename = item.split('_')[1]
            if (mode == 'train') or (mode == 'valid'):
                batch_out.append((img_LQs, img_GT))
            elif mode == 'test':
                batch_out.append((img_LQs, img_GT, videoname, framename))
            else:
                raise NotImplementedError("mode {} not implemented".format(mode))
            if len(batch_out) == batch_size:
                queue.put(batch_out)
                batch_out = []
        queue.put(None)
    def queue_reader():
        fl = filelist
        if is_training:
            random.shuffle(fl)
        n = num_threads
        queue_size = 20
        reader_lists = [None] * n
        file_num = int(len(fl) // n)
        for i in range(n):
            if i < len(reader_lists) - 1:
                tmp_list = fl[i * file_num:(i + 1) * file_num]
            else:
                tmp_list = fl[i * file_num:]
            reader_lists[i] = tmp_list
        queue = multiprocessing.Queue(queue_size)
        p_list = [None] * len(reader_lists)
        # for reader_list in reader_lists:
        for i in range(len(reader_lists)):
            reader_list = reader_lists[i]
            p_list[i] = multiprocessing.Process(
                target=read_into_queue, args=(reader_list, queue))
            p_list[i].start()
        reader_num = len(reader_lists)
        finish_num = 0
        while finish_num < reader_num:
            sample = queue.get()
            if sample is None:
                finish_num += 1
            else:
                yield sample
        for i in range(len(p_list)):
            if p_list[i].is_alive():
                p_list[i].join()
    return queue_reader
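The four padding modes documented in `get_test_neighbor_frames` are easiest to follow with concrete indices. The sketch below is a condensed, standalone restatement of that index logic, written only to check the docstring's example. `neighbor_indices` is a hypothetical name, and unlike the original it raises `KeyError` rather than `ValueError` for an unknown mode.

```python
def neighbor_indices(center, n, max_n=100, padding="new_info"):
    """Indices of the n frames around `center`, padded at sequence borders."""
    last = max_n - 1
    half = n // 2
    out = []
    for i in range(center - half, center + half + 1):
        if i < 0:
            # Before the first frame: pick a substitute index per padding mode.
            idx = {
                "replicate": 0,
                "reflection": -i,
                "new_info": (center + half) - i,
                "circle": n + i,
            }[padding]
        elif i > last:
            # Past the last frame: mirror of the cases above.
            idx = {
                "replicate": last,
                "reflection": 2 * last - i,
                "new_info": (center - half) - (i - last),
                "circle": i - n,
            }[padding]
        else:
            idx = i
        out.append(idx)
    return out


for mode in ("replicate", "reflection", "new_info", "circle"):
    print(mode, neighbor_indices(0, 5, padding=mode), neighbor_indices(99, 5, padding=mode))
```

The default `new_info` mode fills border positions with frames the window has not already used, instead of repeating or mirroring frames that are already in it.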
applications/EDVR/reader/edvr_reader.pybk
Deleted (100644 → 0), shown as of parent 39937404
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import cv2
import math
import random
import multiprocessing
import functools
import numpy as np
import paddle
import cv2
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
python_ver = sys.version_info
random.seed(0)
np.random.seed(0)
class EDVRReader(DataReader):
    """
    Data reader for video super resolution task fit for EDVR model.
    This is specified for REDS dataset.
    """
    def __init__(self, name, mode, cfg):
        super(EDVRReader, self).__init__(name, mode, cfg)
        self.format = cfg.MODEL.format
        self.crop_size = self.get_config_from_sec(mode, 'crop_size')
        self.interval_list = self.get_config_from_sec(mode, 'interval_list')
        self.random_reverse = self.get_config_from_sec(mode, 'random_reverse')
        self.number_frames = self.get_config_from_sec(mode, 'number_frames')
        # set batch size and file list
        self.batch_size = cfg[mode.upper()]['batch_size']
        self.fileroot = cfg[mode.upper()]['file_root']
        self.use_flip = self.get_config_from_sec(mode, 'use_flip', False)
        self.use_rot = self.get_config_from_sec(mode, 'use_rot', False)
        self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1)
        self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024)
        self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False)
        if self.mode != 'infer':
            self.gtroot = self.get_config_from_sec(mode, 'gt_root')
            self.scale = self.get_config_from_sec(mode, 'scale', 1)
            self.LR_input = (self.scale > 1)
        if self.fix_random_seed:
            random.seed(0)
            np.random.seed(0)
            self.num_reader_threads = 1
    """
    def create_reader(self):
        logger.info('initialize reader ... ')
        self.filelist = []
        for video_name in os.listdir(self.fileroot):
            if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
                continue
            for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
                frame_idx = frame_name.split('.')[0]
                video_frame_idx = video_name + '_' + frame_idx
                # for each item in self.filelist is like '010_00000015', '260_00000090'
                self.filelist.append(video_frame_idx)
        #self.filelist.sort()
        def reader_():
            ### not implemented border mode, maybe add later ############
            if self.mode == 'train':
                random.shuffle(self.filelist)
            for item in self.filelist:
                #print(item)
                video_name = item.split('_')[0]
                frame_name = item.split('_')[1]
                ngb_frames, name_b = get_neighbor_frames(frame_name, \
                                                         number_frames = self.number_frames, \
                                                         interval_list = self.interval_list, \
                                                         random_reverse = self.random_reverse)
                frame_name = name_b
                #print('key2', ngb_frames, name_b)
                img_GT = read_img(os.path.join(self.gtroot, video_name, frame_name + '.png'))
                #print('gt_mean', np.mean(img_GT))
                frame_list = []
                for ngb_frm in ngb_frames:
                    ngb_name = "%08d"%ngb_frm
                    #img = read_img(os.path.join(self.fileroot, video_name, frame_name + '.png'))
                    img = read_img(os.path.join(self.fileroot, video_name, ngb_name + '.png'))
                    frame_list.append(img)
                    #print('img_mean', np.mean(img))
                H, W, C = frame_list[0].shape
                # add random crop
                if self.LR_input:
                    LQ_size = self.crop_size // self.scale
                    rnd_h = random.randint(0, max(0, H - LQ_size))
                    rnd_w = random.randint(0, max(0, W - LQ_size))
                    #print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
                    frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
                    rnd_h_HR, rnd_w_HR = int(rnd_h * self.scale), int(rnd_w * self.scale)
                    img_GT = img_GT[rnd_h_HR:rnd_h_HR + self.crop_size, rnd_w_HR:rnd_w_HR + self.crop_size, :]
                else:
                    rnd_h = random.randint(0, max(0, H - self.crop_size))
                    rnd_w = random.randint(0, max(0, W - self.crop_size))
                    frame_list = [v[rnd_h:rnd_h + self.crop_size, rnd_w:rnd_w + self.crop_size, :] for v in frame_list]
                    img_GT = img_GT[rnd_h:rnd_h + self.crop_size, rnd_w:rnd_w + self.crop_size, :]
                # add random flip and rotation
                frame_list.append(img_GT)
                rlt = img_augment(frame_list, self.use_flip, self.use_rot)
                frame_list = rlt[0:-1]
                img_GT = rlt[-1]
                # stack LQ images to NHWC, N is the frame number
                img_LQs = np.stack(frame_list, axis=0)
                # BGR to RGB, HWC to CHW, numpy to tensor
                img_GT = img_GT[:, :, [2, 1, 0]]
                img_LQs = img_LQs[:, :, :, [2, 1, 0]]
                img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
                img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
                yield img_LQs, img_GT
        def _batch_reader():
            batch_out = []
            for img_LQs, img_GT in reader_():
                #print('lq', img_LQs.shape)
                #print('gt', img_GT.shape)
                batch_out.append((img_LQs, img_GT))
                if len(batch_out) == self.batch_size:
                    yield batch_out
                    batch_out = []
        return _batch_reader
    """
    def create_reader(self):
        logger.info('initialize reader ... ')
        self.filelist = []
        for video_name in os.listdir(self.fileroot):
            if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
                continue
            for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
                frame_idx = frame_name.split('.')[0]
                video_frame_idx = video_name + '_' + frame_idx
                # for each item in self.filelist is like '010_00000015', '260_00000090'
                self.filelist.append(video_frame_idx)
        if self.num_reader_threads == 1:
            reader_func = make_reader
        else:
            reader_func = make_multi_reader
        return reader_func(filelist = self.filelist,
                           num_threads = self.num_reader_threads,
                           batch_size = self.batch_size,
                           is_training = (self.mode == 'train'),
                           number_frames = self.number_frames,
                           interval_list = self.interval_list,
                           random_reverse = self.random_reverse,
                           gtroot = self.gtroot,
                           fileroot = self.fileroot,
                           LR_input = self.LR_input,
                           crop_size = self.crop_size,
                           scale = self.scale,
                           use_flip = self.use_flip,
                           use_rot = self.use_rot)
def get_sample_data(item, number_frames, interval_list, random_reverse, gtroot, fileroot,
                    LR_input, crop_size, scale, use_flip, use_rot):
    video_name = item.split('_')[0]
    frame_name = item.split('_')[1]
    ngb_frames, name_b = get_neighbor_frames(frame_name, \
                                             number_frames = number_frames, \
                                             interval_list = interval_list, \
                                             random_reverse = random_reverse)
    frame_name = name_b
    #print('key2', ngb_frames, name_b)
    img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'))
    #print('gt_mean', np.mean(img_GT))
    frame_list = []
    for ngb_frm in ngb_frames:
        ngb_name = "%08d"%ngb_frm
        #img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
        img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png'))
        frame_list.append(img)
        #print('img_mean', np.mean(img))
    H, W, C = frame_list[0].shape
    # add random crop
    if LR_input:
        LQ_size = crop_size // scale
        rnd_h = random.randint(0, max(0, H - LQ_size))
        rnd_w = random.randint(0, max(0, W - LQ_size))
        #print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
        frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
        rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
        img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :]
    else:
        rnd_h = random.randint(0, max(0, H - crop_size))
        rnd_w = random.randint(0, max(0, W - crop_size))
        frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list]
        img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
    # add random flip and rotation
    frame_list.append(img_GT)
    rlt = img_augment(frame_list, use_flip, use_rot)
    frame_list = rlt[0:-1]
    img_GT = rlt[-1]
    # stack LQ images to NHWC, N is the frame number
    img_LQs = np.stack(frame_list, axis=0)
    # BGR to RGB, HWC to CHW, numpy to tensor
    img_GT = img_GT[:, :, [2, 1, 0]]
    img_LQs = img_LQs[:, :, :, [2, 1, 0]]
    img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
    img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
    return img_LQs, img_GT
def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, bordermode=False):
    center_frame_idx = int(frame_name)
    half_N_frames = number_frames // 2
    #### determine the neighbor frames
    interval = random.choice(interval_list)
    if bordermode:
        direction = 1 # 1: forward; 0: backward
        if random_reverse and random.random() < 0.5:
            direction = random.choice([0, 1])
        if center_frame_idx + interval * (number_frames - 1) > 99:
            direction = 0
        elif center_frame_idx - interval * (number_frames - 1) < 0:
            direction = 1
        # get the neighbor list
        if direction == 1:
            neighbor_list = list(
                range(center_frame_idx, center_frame_idx + interval * number_frames, interval))
        else:
            neighbor_list = list(
                range(center_frame_idx, center_frame_idx - interval * number_frames, -interval))
        name_b = '{:08d}'.format(neighbor_list[0])
    else:
        # ensure not exceeding the borders
        while (center_frame_idx + half_N_frames * interval >
               99) or (center_frame_idx - half_N_frames * interval < 0):
            center_frame_idx = random.randint(0, 99)
        # get the neighbor list
        neighbor_list = list(
            range(center_frame_idx - half_N_frames * interval,
                  center_frame_idx + half_N_frames * interval + 1, interval))
        if random_reverse and random.random() < 0.5:
            neighbor_list.reverse()
        name_b = '{:08d}'.format(neighbor_list[half_N_frames])
    assert len(neighbor_list) == number_frames, \
        "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)
    return neighbor_list, name_b
def read_img(path, size=None):
    """read image by cv2
    return: Numpy float32, HWC, BGR, [0,1]"""
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
def img_augment(img_list, hflip=True, rot=True):
    """horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5
    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img
    return [_augment(img) for img in img_list]
def make_reader(filelist,
                num_threads,
                batch_size,
                is_training,
                number_frames,
                interval_list,
                random_reverse,
                gtroot,
                fileroot,
                LR_input,
                crop_size,
                scale,
                use_flip,
                use_rot):
    fl = filelist
    def reader_():
        if is_training:
            random.shuffle(fl)
        batch_out = []
        for item in fl:
            img_LQs, img_GT = get_sample_data(item,
                number_frames, interval_list, random_reverse, gtroot, fileroot,
                LR_input, crop_size, scale, use_flip, use_rot)
            batch_out.append((img_LQs, img_GT))
            if len(batch_out) == batch_size:
                yield batch_out
                batch_out = []
    return reader_
def make_multi_reader(filelist,
                      num_threads,
                      batch_size,
                      is_training,
                      number_frames,
                      interval_list,
                      random_reverse,
                      gtroot,
                      fileroot,
                      LR_input,
                      crop_size,
                      scale,
                      use_flip,
                      use_rot):
    def read_into_queue(flq, queue):
        batch_out = []
        for item in flq:
            img_LQs, img_GT = get_sample_data(item,
                number_frames, interval_list, random_reverse, gtroot, fileroot,
                LR_input, crop_size, scale, use_flip, use_rot)
            batch_out.append((img_LQs, img_GT))
            if len(batch_out) == batch_size:
                queue.put(batch_out)
                batch_out = []
        queue.put(None)
    def queue_reader():
        fl = filelist
        if is_training:
            random.shuffle(fl)
        n = num_threads
        queue_size = 20
        reader_lists = [None] * n
        file_num = int(len(fl) // n)
        for i in range(n):
            if i < len(reader_lists) - 1:
                tmp_list = fl[i * file_num:(i + 1) * file_num]
            else:
                tmp_list = fl[i * file_num:]
            reader_lists[i] = tmp_list
        queue = multiprocessing.Queue(queue_size)
        p_list = [None] * len(reader_lists)
        # for reader_list in reader_lists:
        for i in range(len(reader_lists)):
            reader_list = reader_lists[i]
            p_list[i] = multiprocessing.Process(
                target=read_into_queue, args=(reader_list, queue))
            p_list[i].start()
        reader_num = len(reader_lists)
        finish_num = 0
        while finish_num < reader_num:
            sample = queue.get()
            if sample is None:
                finish_num += 1
            else:
                yield sample
        for i in range(len(p_list)):
            if p_list[i].is_alive():
                p_list[i].join()
    return queue_reader
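Both deleted readers crop the low-resolution frames and the ground-truth frame at matching positions: the offset is drawn on the low-resolution grid and multiplied by `scale` for the ground truth. The standalone sketch below illustrates that idea. `paired_random_crop` is a hypothetical helper and the array sizes are made up for the example, not taken from REDS.

```python
import random

import numpy as np


def paired_random_crop(lq_frames, gt, crop_size, scale, rng=random):
    """Crop LQ frames to crop_size // scale and the GT frame to crop_size, aligned."""
    lq_size = crop_size // scale
    h, w, _ = lq_frames[0].shape
    top = rng.randint(0, max(0, h - lq_size))    # randint is inclusive on both ends
    left = rng.randint(0, max(0, w - lq_size))
    lq_crops = [f[top:top + lq_size, left:left + lq_size, :] for f in lq_frames]
    top_hr, left_hr = top * scale, left * scale  # same location on the GT grid
    gt_crop = gt[top_hr:top_hr + crop_size, left_hr:left_hr + crop_size, :]
    return lq_crops, gt_crop


# Synthetic data: 5 LQ frames of 90x160 and a GT frame that is frame 2 upsampled 4x
# by pixel repetition, so alignment can be checked exactly.
lq_frames = [np.random.rand(90, 160, 3).astype("float32") for _ in range(5)]
gt = lq_frames[2].repeat(4, axis=0).repeat(4, axis=1)

lq_crops, gt_crop = paired_random_crop(lq_frames, gt, crop_size=256, scale=4,
                                       rng=random.Random(0))
print(lq_crops[0].shape, gt_crop.shape)
```

Drawing the offset on the low-resolution grid keeps the ground-truth offset an exact multiple of `scale`, which avoids sub-pixel misalignment between input and target.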
applications/EDVR/run.sh
@@ -11,73 +11,24 @@ mode=$1
 name=$2
 configs=$3
-pretrain="./tmp/name_map/paddle_state_dict.npz" # set pretrain model path if needed
-resume="" # set pretrain model path if needed
-save_dir="./data/checkpoints"
+#pretrain="./tmp/name_map/paddle_state_dict.npz" # set pretrain model path if needed
+#resume="" # set pretrain model path if needed
+#save_dir="./data/checkpoints"
 save_inference_dir="./data/inference_model"
 use_gpu=True
 fix_random_seed=False
 log_interval=1
 valid_interval=1
 #weights="./data/checkpoints/EDVR_epoch721.pdparams" #set the path of weights to enable eval and predicut, just ignore this when training
 #weights="./data/checkpoints_with_tsa/EDVR_epoch821.pdparams"
 weights="./weights/paddle_state_dict_L.npz"
 #weights="./weights/"
 #export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 export CUDA_VISIBLE_DEVICES=4,5,6,7 #0,1,5,6 fast, 2,3,4,7 slow
 #export CUDA_VISIBLE_DEVICES=7
 export FLAGS_fast_eager_deletion_mode=1
 export FLAGS_eager_delete_tensor_gb=0.0
 export FLAGS_fraction_of_gpu_memory_to_use=0.98
-if [ "$mode" == "train" ]; then
-    echo $mode $name $configs $resume $pretrain
-    if [ "$resume"x != ""x ]; then
-        python train.py --model_name=$name \
-                        --config=$configs \
-                        --resume=$resume \
-                        --log_interval=$log_interval \
-                        --valid_interval=$valid_interval \
-                        --use_gpu=$use_gpu \
-                        --save_dir=$save_dir \
-                        --fix_random_seed=$fix_random_seed
-    elif [ "$pretrain"x != ""x ]; then
-        python train.py --model_name=$name \
-                        --config=$configs \
-                        --pretrain=$pretrain \
-                        --log_interval=$log_interval \
-                        --valid_interval=$valid_interval \
-                        --use_gpu=$use_gpu \
-                        --save_dir=$save_dir \
-                        --fix_random_seed=$fix_random_seed
-    else
-        python train.py --model_name=$name \
-                        --config=$configs \
-                        --log_interval=$log_interval \
-                        --valid_interval=$valid_interval \
-                        --use_gpu=$use_gpu \
-                        --save_dir=$save_dir \
-                        --fix_random_seed=$fix_random_seed
-    fi
-elif [ "$mode"x == "eval"x ]; then
-    echo $mode $name $configs $weights
-    if [ "$weights"x != ""x ]; then
-        python eval.py --model_name=$name \
-                       --config=$configs \
-                       --log_interval=$log_interval \
-                       --weights=$weights \
-                       --use_gpu=$use_gpu
-    else
-        python eval.py --model_name=$name \
-                       --config=$configs \
-                       --log_interval=$log_interval \
-                       --use_gpu=$use_gpu
-    fi
-elif [ "$mode"x == "predict"x ]; then
+if [ "$mode"x == "predict"x ]; then
     echo $mode $name $configs $weights
     if [ "$weights"x != ""x ]; then
         python predict.py --model_name=$name \
...