Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
031e15f1
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
031e15f1
编写于
8月 13, 2020
作者:
L
LielinJiang
提交者:
GitHub
8月 13, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7 from lijianshe02/master
add edvr inference code
上级
38701e31
015a8f06
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
904 addition
and
0 deletion
+904
-0
applications/EDVR/configs/edvr_L.yaml
applications/EDVR/configs/edvr_L.yaml
+24
-0
applications/EDVR/predict.py
applications/EDVR/predict.py
+174
-0
applications/EDVR/reader/__init__.py
applications/EDVR/reader/__init__.py
+4
-0
applications/EDVR/reader/edvr_reader.py
applications/EDVR/reader/edvr_reader.py
+434
-0
applications/EDVR/reader/reader_utils.py
applications/EDVR/reader/reader_utils.py
+81
-0
applications/EDVR/run.sh
applications/EDVR/run.sh
+41
-0
applications/EDVR/utils/__init__.py
applications/EDVR/utils/__init__.py
+0
-0
applications/EDVR/utils/config_utils.py
applications/EDVR/utils/config_utils.py
+75
-0
applications/EDVR/utils/utility.py
applications/EDVR/utils/utility.py
+71
-0
未找到文件。
applications/EDVR/configs/edvr_L.yaml
0 → 100644
浏览文件 @
031e15f1
MODEL
:
name
:
"
EDVR"
format
:
"
png"
num_frames
:
5
center
:
2
num_filters
:
128
#64
deform_conv_groups
:
8
front_RBs
:
5
back_RBs
:
40
#10
predeblur
:
False
HR_in
:
False
w_TSA
:
True
#False
INFER
:
scale
:
4
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
batch_size
:
1
file_root
:
"
/workspace/color/input_frames"
inference_model
:
"
/workspace/PaddleGAN/applications/EDVR/data/inference_model"
use_flip
:
False
use_rot
:
False
applications/EDVR/predict.py
0 → 100644
浏览文件 @
031e15f1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
os
import
sys
import
time
import
logging
import
argparse
import
ast
import
numpy
as
np
try
:
import
cPickle
as
pickle
except
:
import
pickle
import
paddle.fluid
as
fluid
import
cv2
from
utils.config_utils
import
*
#import models
from
reader
import
get_reader
#from metrics import get_metrics
from
utils.utility
import
check_cuda
from
utils.utility
import
check_version
logging
.
root
.
handlers
=
[]
FORMAT
=
'[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
format
=
FORMAT
,
stream
=
sys
.
stdout
)
logger
=
logging
.
getLogger
(
__name__
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model_name'
,
type
=
str
,
default
=
'AttentionCluster'
,
help
=
'name of model to train.'
)
parser
.
add_argument
(
'--inference_model'
,
type
=
str
,
default
=
'./data/inference_model'
,
help
=
'path of inference_model.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
default
=
'configs/attention_cluster.txt'
,
help
=
'path to config file of model'
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'default use gpu.'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
1
,
help
=
'sample number in a batch for inference.'
)
parser
.
add_argument
(
'--filelist'
,
type
=
str
,
default
=
None
,
help
=
'path to inferenece data file lists file.'
)
parser
.
add_argument
(
'--log_interval'
,
type
=
int
,
default
=
1
,
help
=
'mini-batch interval to log.'
)
parser
.
add_argument
(
'--infer_topk'
,
type
=
int
,
default
=
20
,
help
=
'topk predictions to restore.'
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
os
.
path
.
join
(
'data'
,
'predict_results'
),
help
=
'directory to store results'
)
parser
.
add_argument
(
'--video_path'
,
type
=
str
,
default
=
None
,
help
=
'directory to store results'
)
args
=
parser
.
parse_args
()
return
args
def
get_img
(
pred
):
print
(
'pred shape'
,
pred
.
shape
)
pred
=
pred
.
squeeze
()
pred
=
np
.
clip
(
pred
,
a_min
=
0.
,
a_max
=
1.0
)
pred
=
pred
*
255
pred
=
pred
.
round
()
pred
=
pred
.
astype
(
'uint8'
)
pred
=
np
.
transpose
(
pred
,
(
1
,
2
,
0
))
# chw -> hwc
pred
=
pred
[:,
:,
::
-
1
]
# rgb -> bgr
return
pred
def
save_img
(
img
,
framename
):
dirname
=
'./demo/resultpng'
filename
=
os
.
path
.
join
(
dirname
,
framename
+
'.png'
)
cv2
.
imwrite
(
filename
,
img
)
def
infer
(
args
):
# parse config
config
=
parse_config
(
args
.
config
)
infer_config
=
merge_configs
(
config
,
'infer'
,
vars
(
args
))
print_configs
(
infer_config
,
"Infer"
)
inference_model
=
args
.
inference_model
model_filename
=
'EDVR_model.pdmodel'
params_filename
=
'EDVR_params.pdparams'
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
[
inference_program
,
feed_list
,
fetch_list
]
=
fluid
.
io
.
load_inference_model
(
dirname
=
inference_model
,
model_filename
=
model_filename
,
params_filename
=
params_filename
,
executor
=
exe
)
infer_reader
=
get_reader
(
args
.
model_name
.
upper
(),
'infer'
,
infer_config
)
#infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config)
#infer_metrics.reset()
periods
=
[]
cur_time
=
time
.
time
()
for
infer_iter
,
data
in
enumerate
(
infer_reader
()):
if
args
.
model_name
==
'EDVR'
:
data_feed_in
=
[
items
[
0
]
for
items
in
data
]
video_info
=
[
items
[
1
:]
for
items
in
data
]
infer_outs
=
exe
.
run
(
inference_program
,
fetch_list
=
fetch_list
,
feed
=
{
feed_list
[
0
]:
np
.
array
(
data_feed_in
)})
infer_result_list
=
[
item
for
item
in
infer_outs
]
videonames
=
[
item
[
0
]
for
item
in
video_info
]
framenames
=
[
item
[
1
]
for
item
in
video_info
]
for
i
in
range
(
len
(
infer_result_list
)):
img_i
=
get_img
(
infer_result_list
[
i
])
save_img
(
img_i
,
'img'
+
videonames
[
i
]
+
framenames
[
i
])
prev_time
=
cur_time
cur_time
=
time
.
time
()
period
=
cur_time
-
prev_time
periods
.
append
(
period
)
#infer_metrics.accumulate(infer_result_list)
if
args
.
log_interval
>
0
and
infer_iter
%
args
.
log_interval
==
0
:
logger
.
info
(
'Processed {} samples'
.
format
(
infer_iter
+
1
))
logger
.
info
(
'[INFER] infer finished. average time: {}'
.
format
(
np
.
mean
(
periods
)))
if
not
os
.
path
.
isdir
(
args
.
save_dir
):
os
.
makedirs
(
args
.
save_dir
)
#infer_metrics.finalize_and_log_out(savedir=args.save_dir)
if
__name__
==
"__main__"
:
args
=
parse_args
()
# check whether the installed paddle is compiled with GPU
check_cuda
(
args
.
use_gpu
)
check_version
()
logger
.
info
(
args
)
infer
(
args
)
applications/EDVR/reader/__init__.py
0 → 100644
浏览文件 @
031e15f1
from
.reader_utils
import
regist_reader
,
get_reader
from
.edvr_reader
import
EDVRReader
regist_reader
(
"EDVR"
,
EDVRReader
)
applications/EDVR/reader/edvr_reader.py
0 → 100644
浏览文件 @
031e15f1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
os
import
sys
import
cv2
import
math
import
random
import
multiprocessing
import
functools
import
numpy
as
np
import
paddle
import
cv2
import
logging
from
.reader_utils
import
DataReader
logger
=
logging
.
getLogger
(
__name__
)
python_ver
=
sys
.
version_info
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
class
EDVRReader
(
DataReader
):
"""
Data reader for video super resolution task fit for EDVR model.
This is specified for REDS dataset.
"""
def
__init__
(
self
,
name
,
mode
,
cfg
):
super
(
EDVRReader
,
self
).
__init__
(
name
,
mode
,
cfg
)
self
.
format
=
cfg
.
MODEL
.
format
self
.
crop_size
=
self
.
get_config_from_sec
(
mode
,
'crop_size'
)
self
.
interval_list
=
self
.
get_config_from_sec
(
mode
,
'interval_list'
)
self
.
random_reverse
=
self
.
get_config_from_sec
(
mode
,
'random_reverse'
)
self
.
number_frames
=
self
.
get_config_from_sec
(
mode
,
'number_frames'
)
# set batch size and file list
self
.
batch_size
=
cfg
[
mode
.
upper
()][
'batch_size'
]
self
.
fileroot
=
cfg
[
mode
.
upper
()][
'file_root'
]
self
.
use_flip
=
self
.
get_config_from_sec
(
mode
,
'use_flip'
,
False
)
self
.
use_rot
=
self
.
get_config_from_sec
(
mode
,
'use_rot'
,
False
)
self
.
num_reader_threads
=
self
.
get_config_from_sec
(
mode
,
'num_reader_threads'
,
1
)
self
.
buf_size
=
self
.
get_config_from_sec
(
mode
,
'buf_size'
,
1024
)
self
.
fix_random_seed
=
self
.
get_config_from_sec
(
mode
,
'fix_random_seed'
,
False
)
if
self
.
mode
!=
'infer'
:
self
.
gtroot
=
self
.
get_config_from_sec
(
mode
,
'gt_root'
)
self
.
scale
=
self
.
get_config_from_sec
(
mode
,
'scale'
,
1
)
self
.
LR_input
=
(
self
.
scale
>
1
)
if
self
.
fix_random_seed
:
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
self
.
num_reader_threads
=
1
def
create_reader
(
self
):
logger
.
info
(
'initialize reader ... '
)
self
.
filelist
=
[]
for
video_name
in
os
.
listdir
(
self
.
fileroot
):
if
(
self
.
mode
==
'train'
)
and
(
video_name
in
[
'000'
,
'011'
,
'015'
,
'020'
]):
continue
for
frame_name
in
os
.
listdir
(
os
.
path
.
join
(
self
.
fileroot
,
video_name
)):
frame_idx
=
frame_name
.
split
(
'.'
)[
0
]
video_frame_idx
=
video_name
+
'_'
+
frame_idx
# for each item in self.filelist is like '010_00000015', '260_00000090'
self
.
filelist
.
append
(
video_frame_idx
)
if
self
.
mode
==
'test'
or
self
.
mode
==
'infer'
:
self
.
filelist
.
sort
()
if
self
.
num_reader_threads
==
1
:
reader_func
=
make_reader
else
:
reader_func
=
make_multi_reader
if
self
.
mode
!=
'infer'
:
return
reader_func
(
filelist
=
self
.
filelist
,
num_threads
=
self
.
num_reader_threads
,
batch_size
=
self
.
batch_size
,
is_training
=
(
self
.
mode
==
'train'
),
number_frames
=
self
.
number_frames
,
interval_list
=
self
.
interval_list
,
random_reverse
=
self
.
random_reverse
,
fileroot
=
self
.
fileroot
,
crop_size
=
self
.
crop_size
,
use_flip
=
self
.
use_flip
,
use_rot
=
self
.
use_rot
,
gtroot
=
self
.
gtroot
,
LR_input
=
self
.
LR_input
,
scale
=
self
.
scale
,
mode
=
self
.
mode
)
else
:
return
reader_func
(
filelist
=
self
.
filelist
,
num_threads
=
self
.
num_reader_threads
,
batch_size
=
self
.
batch_size
,
is_training
=
(
self
.
mode
==
'train'
),
number_frames
=
self
.
number_frames
,
interval_list
=
self
.
interval_list
,
random_reverse
=
self
.
random_reverse
,
fileroot
=
self
.
fileroot
,
crop_size
=
self
.
crop_size
,
use_flip
=
self
.
use_flip
,
use_rot
=
self
.
use_rot
,
gtroot
=
''
,
LR_input
=
True
,
scale
=
4
,
mode
=
self
.
mode
)
def
get_sample_data
(
item
,
number_frames
,
interval_list
,
random_reverse
,
fileroot
,
crop_size
,
use_flip
,
use_rot
,
gtroot
,
LR_input
,
scale
,
mode
=
'train'
):
video_name
=
item
.
split
(
'_'
)[
0
]
frame_name
=
item
.
split
(
'_'
)[
1
]
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
ngb_frames
,
name_b
=
get_neighbor_frames
(
frame_name
,
\
number_frames
=
number_frames
,
\
interval_list
=
interval_list
,
\
random_reverse
=
random_reverse
)
elif
(
mode
==
'test'
)
or
(
mode
==
'infer'
):
ngb_frames
,
name_b
=
get_test_neighbor_frames
(
int
(
frame_name
),
number_frames
)
else
:
raise
NotImplementedError
(
'mode {} not implemented'
.
format
(
mode
))
frame_name
=
name_b
print
(
'key2'
,
ngb_frames
,
name_b
)
if
mode
!=
'infer'
:
img_GT
=
read_img
(
os
.
path
.
join
(
gtroot
,
video_name
,
frame_name
+
'.png'
),
is_gt
=
True
)
#print('gt_mean', np.mean(img_GT))
frame_list
=
[]
for
ngb_frm
in
ngb_frames
:
ngb_name
=
"%04d"
%
ngb_frm
#img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
img
=
read_img
(
os
.
path
.
join
(
fileroot
,
video_name
,
ngb_name
+
'.png'
))
frame_list
.
append
(
img
)
#print('img_mean', np.mean(img))
H
,
W
,
C
=
frame_list
[
0
].
shape
# add random crop
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
if
LR_input
:
LQ_size
=
crop_size
//
scale
rnd_h
=
random
.
randint
(
0
,
max
(
0
,
H
-
LQ_size
))
rnd_w
=
random
.
randint
(
0
,
max
(
0
,
W
-
LQ_size
))
#print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
frame_list
=
[
v
[
rnd_h
:
rnd_h
+
LQ_size
,
rnd_w
:
rnd_w
+
LQ_size
,
:]
for
v
in
frame_list
]
rnd_h_HR
,
rnd_w_HR
=
int
(
rnd_h
*
scale
),
int
(
rnd_w
*
scale
)
img_GT
=
img_GT
[
rnd_h_HR
:
rnd_h_HR
+
crop_size
,
rnd_w_HR
:
rnd_w_HR
+
crop_size
,
:]
else
:
rnd_h
=
random
.
randint
(
0
,
max
(
0
,
H
-
crop_size
))
rnd_w
=
random
.
randint
(
0
,
max
(
0
,
W
-
crop_size
))
frame_list
=
[
v
[
rnd_h
:
rnd_h
+
crop_size
,
rnd_w
:
rnd_w
+
crop_size
,
:]
for
v
in
frame_list
]
img_GT
=
img_GT
[
rnd_h
:
rnd_h
+
crop_size
,
rnd_w
:
rnd_w
+
crop_size
,
:]
# add random flip and rotation
if
mode
!=
'infer'
:
frame_list
.
append
(
img_GT
)
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
rlt
=
img_augment
(
frame_list
,
use_flip
,
use_rot
)
else
:
rlt
=
frame_list
if
mode
!=
'infer'
:
frame_list
=
rlt
[
0
:
-
1
]
img_GT
=
rlt
[
-
1
]
else
:
frame_list
=
rlt
# stack LQ images to NHWC, N is the frame number
img_LQs
=
np
.
stack
(
frame_list
,
axis
=
0
)
# BGR to RGB, HWC to CHW, numpy to tensor
img_LQs
=
img_LQs
[:,
:,
:,
[
2
,
1
,
0
]]
img_LQs
=
np
.
transpose
(
img_LQs
,
(
0
,
3
,
1
,
2
)).
astype
(
'float32'
)
if
mode
!=
'infer'
:
img_GT
=
img_GT
[:,
:,
[
2
,
1
,
0
]]
img_GT
=
np
.
transpose
(
img_GT
,
(
2
,
0
,
1
)).
astype
(
'float32'
)
return
img_LQs
,
img_GT
else
:
return
img_LQs
def
get_test_neighbor_frames
(
crt_i
,
N
,
max_n
=
100
,
padding
=
'new_info'
):
"""Generate an index list for reading N frames from a sequence of images
Args:
crt_i (int): current center index
max_n (int): max number of the sequence of images (calculated from 1)
N (int): reading N frames
padding (str): padding mode, one of replicate | reflection | new_info | circle
Example: crt_i = 0, N = 5
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
new_info: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
return_l (list [int]): a list of indexes
"""
max_n
=
max_n
-
1
n_pad
=
N
//
2
return_l
=
[]
for
i
in
range
(
crt_i
-
n_pad
,
crt_i
+
n_pad
+
1
):
if
i
<
0
:
if
padding
==
'replicate'
:
add_idx
=
0
elif
padding
==
'reflection'
:
add_idx
=
-
i
elif
padding
==
'new_info'
:
add_idx
=
(
crt_i
+
n_pad
)
+
(
-
i
)
elif
padding
==
'circle'
:
add_idx
=
N
+
i
else
:
raise
ValueError
(
'Wrong padding mode'
)
elif
i
>
max_n
:
if
padding
==
'replicate'
:
add_idx
=
max_n
elif
padding
==
'reflection'
:
add_idx
=
max_n
*
2
-
i
elif
padding
==
'new_info'
:
add_idx
=
(
crt_i
-
n_pad
)
-
(
i
-
max_n
)
elif
padding
==
'circle'
:
add_idx
=
i
-
N
else
:
raise
ValueError
(
'Wrong padding mode'
)
else
:
add_idx
=
i
return_l
.
append
(
add_idx
)
name_b
=
'{:08d}'
.
format
(
crt_i
)
return
return_l
,
name_b
def
get_neighbor_frames
(
frame_name
,
number_frames
,
interval_list
,
random_reverse
,
max_frame
=
99
,
bordermode
=
False
):
center_frame_idx
=
int
(
frame_name
)
half_N_frames
=
number_frames
//
2
#### determine the neighbor frames
interval
=
random
.
choice
(
interval_list
)
if
bordermode
:
direction
=
1
# 1: forward; 0: backward
if
random_reverse
and
random
.
random
()
<
0.5
:
direction
=
random
.
choice
([
0
,
1
])
if
center_frame_idx
+
interval
*
(
number_frames
-
1
)
>
max_frame
:
direction
=
0
elif
center_frame_idx
-
interval
*
(
number_frames
-
1
)
<
0
:
direction
=
1
# get the neighbor list
if
direction
==
1
:
neighbor_list
=
list
(
range
(
center_frame_idx
,
center_frame_idx
+
interval
*
number_frames
,
interval
))
else
:
neighbor_list
=
list
(
range
(
center_frame_idx
,
center_frame_idx
-
interval
*
number_frames
,
-
interval
))
name_b
=
'{:08d}'
.
format
(
neighbor_list
[
0
])
else
:
# ensure not exceeding the borders
while
(
center_frame_idx
+
half_N_frames
*
interval
>
max_frame
)
or
(
center_frame_idx
-
half_N_frames
*
interval
<
0
):
center_frame_idx
=
random
.
randint
(
0
,
max_frame
)
# get the neighbor list
neighbor_list
=
list
(
range
(
center_frame_idx
-
half_N_frames
*
interval
,
center_frame_idx
+
half_N_frames
*
interval
+
1
,
interval
))
if
random_reverse
and
random
.
random
()
<
0.5
:
neighbor_list
.
reverse
()
name_b
=
'{:08d}'
.
format
(
neighbor_list
[
half_N_frames
])
assert
len
(
neighbor_list
)
==
number_frames
,
\
"frames slected have length({}), but it should be ({})"
.
format
(
len
(
neighbor_list
),
number_frames
)
return
neighbor_list
,
name_b
def
read_img
(
path
,
size
=
None
,
is_gt
=
False
):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]"""
img
=
cv2
.
imread
(
path
,
cv2
.
IMREAD_UNCHANGED
)
#if not is_gt:
# #print(path)
# img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
#print("path: ", path)
img
=
img
.
astype
(
np
.
float32
)
/
255.
if
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
# some images have 4 channels
if
img
.
shape
[
2
]
>
3
:
img
=
img
[:,
:,
:
3
]
return
img
def
img_augment
(
img_list
,
hflip
=
True
,
rot
=
True
):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
hflip
=
hflip
and
random
.
random
()
<
0.5
vflip
=
rot
and
random
.
random
()
<
0.5
rot90
=
rot
and
random
.
random
()
<
0.5
def
_augment
(
img
):
if
hflip
:
img
=
img
[:,
::
-
1
,
:]
if
vflip
:
img
=
img
[::
-
1
,
:,
:]
if
rot90
:
img
=
img
.
transpose
(
1
,
0
,
2
)
return
img
return
[
_augment
(
img
)
for
img
in
img_list
]
def
make_reader
(
filelist
,
num_threads
,
batch_size
,
is_training
,
number_frames
,
interval_list
,
random_reverse
,
fileroot
,
crop_size
,
use_flip
,
use_rot
,
gtroot
,
LR_input
,
scale
,
mode
=
'train'
):
fl
=
filelist
def
reader_
():
if
is_training
:
random
.
shuffle
(
fl
)
batch_out
=
[]
for
item
in
fl
:
if
mode
!=
'infer'
:
img_LQs
,
img_GT
=
get_sample_data
(
item
,
number_frames
,
interval_list
,
random_reverse
,
fileroot
,
crop_size
,
use_flip
,
use_rot
,
gtroot
,
LR_input
,
scale
,
mode
)
else
:
img_LQs
=
get_sample_data
(
item
,
number_frames
,
interval_list
,
random_reverse
,
fileroot
,
crop_size
,
use_flip
,
use_rot
,
gtroot
,
LR_input
,
scale
,
mode
)
videoname
=
item
.
split
(
'_'
)[
0
]
framename
=
item
.
split
(
'_'
)[
1
]
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
batch_out
.
append
((
img_LQs
,
img_GT
))
elif
mode
==
'test'
:
batch_out
.
append
((
img_LQs
,
img_GT
,
videoname
,
framename
))
elif
mode
==
'infer'
:
batch_out
.
append
((
img_LQs
,
videoname
,
framename
))
else
:
raise
NotImplementedError
(
"mode {} not implemented"
.
format
(
mode
))
if
len
(
batch_out
)
==
batch_size
:
yield
batch_out
batch_out
=
[]
return
reader_
def
make_multi_reader
(
filelist
,
num_threads
,
batch_size
,
is_training
,
number_frames
,
interval_list
,
random_reverse
,
fileroot
,
crop_size
,
use_flip
,
use_rot
,
gtroot
,
LR_input
,
scale
,
mode
=
'train'
):
def
read_into_queue
(
flq
,
queue
):
batch_out
=
[]
for
item
in
flq
:
if
mode
!=
'infer'
:
img_LQs
,
img_GT
=
get_sample_data
(
item
,
number_frames
,
interval_list
,
random_reverse
,
fileroot
,
crop_size
,
use_flip
,
use_rot
,
gtroot
,
LR_input
,
scale
,
mode
)
else
:
img_LQs
=
get_sample_data
(
item
,
number_frames
,
interval_list
,
random_reverse
,
fileroot
,
crop_size
,
use_flip
,
use_rot
,
gtroot
,
LR_input
,
scale
,
mode
)
videoname
=
item
.
split
(
'_'
)[
0
]
framename
=
item
.
split
(
'_'
)[
1
]
if
(
mode
==
'train'
)
or
(
mode
==
'valid'
):
batch_out
.
append
((
img_LQs
,
img_GT
))
elif
mode
==
'test'
:
batch_out
.
append
((
img_LQs
,
img_GT
,
videoname
,
framename
))
elif
mode
==
'infer'
:
batch_out
.
append
((
img_LQs
,
videoname
,
framename
))
else
:
raise
NotImplementedError
(
"mode {} not implemented"
.
format
(
mode
))
if
len
(
batch_out
)
==
batch_size
:
queue
.
put
(
batch_out
)
batch_out
=
[]
queue
.
put
(
None
)
def
queue_reader
():
fl
=
filelist
if
is_training
:
random
.
shuffle
(
fl
)
n
=
num_threads
queue_size
=
20
reader_lists
=
[
None
]
*
n
file_num
=
int
(
len
(
fl
)
//
n
)
for
i
in
range
(
n
):
if
i
<
len
(
reader_lists
)
-
1
:
tmp_list
=
fl
[
i
*
file_num
:(
i
+
1
)
*
file_num
]
else
:
tmp_list
=
fl
[
i
*
file_num
:]
reader_lists
[
i
]
=
tmp_list
queue
=
multiprocessing
.
Queue
(
queue_size
)
p_list
=
[
None
]
*
len
(
reader_lists
)
# for reader_list in reader_lists:
for
i
in
range
(
len
(
reader_lists
)):
reader_list
=
reader_lists
[
i
]
p_list
[
i
]
=
multiprocessing
.
Process
(
target
=
read_into_queue
,
args
=
(
reader_list
,
queue
))
p_list
[
i
].
start
()
reader_num
=
len
(
reader_lists
)
finish_num
=
0
while
finish_num
<
reader_num
:
sample
=
queue
.
get
()
if
sample
is
None
:
finish_num
+=
1
else
:
yield
sample
for
i
in
range
(
len
(
p_list
)):
if
p_list
[
i
].
is_alive
():
p_list
[
i
].
join
()
return
queue_reader
applications/EDVR/reader/reader_utils.py
0 → 100644
浏览文件 @
031e15f1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
pickle
import
cv2
import
numpy
as
np
import
random
class
ReaderNotFoundError
(
Exception
):
"Error: reader not found"
def
__init__
(
self
,
reader_name
,
avail_readers
):
super
(
ReaderNotFoundError
,
self
).
__init__
()
self
.
reader_name
=
reader_name
self
.
avail_readers
=
avail_readers
def
__str__
(
self
):
msg
=
"Reader {} Not Found.
\n
Availiable readers:
\n
"
.
format
(
self
.
reader_name
)
for
reader
in
self
.
avail_readers
:
msg
+=
" {}
\n
"
.
format
(
reader
)
return
msg
class
DataReader
(
object
):
"""data reader for video input"""
def
__init__
(
self
,
model_name
,
mode
,
cfg
):
self
.
name
=
model_name
self
.
mode
=
mode
self
.
cfg
=
cfg
def
create_reader
(
self
):
"""Not implemented"""
pass
def
get_config_from_sec
(
self
,
sec
,
item
,
default
=
None
):
if
sec
.
upper
()
not
in
self
.
cfg
:
return
default
return
self
.
cfg
[
sec
.
upper
()].
get
(
item
,
default
)
class
ReaderZoo
(
object
):
def
__init__
(
self
):
self
.
reader_zoo
=
{}
def
regist
(
self
,
name
,
reader
):
assert
reader
.
__base__
==
DataReader
,
"Unknow model type {}"
.
format
(
type
(
reader
))
self
.
reader_zoo
[
name
]
=
reader
def
get
(
self
,
name
,
mode
,
cfg
):
for
k
,
v
in
self
.
reader_zoo
.
items
():
if
k
==
name
:
return
v
(
name
,
mode
,
cfg
)
raise
ReaderNotFoundError
(
name
,
self
.
reader_zoo
.
keys
())
# singleton reader_zoo
reader_zoo
=
ReaderZoo
()
def
regist_reader
(
name
,
reader
):
reader_zoo
.
regist
(
name
,
reader
)
def
get_reader
(
name
,
mode
,
cfg
):
reader_model
=
reader_zoo
.
get
(
name
,
mode
,
cfg
)
return
reader_model
.
create_reader
()
applications/EDVR/run.sh
0 → 100644
浏览文件 @
031e15f1
# examples of running programs:
# bash ./run.sh inference EDVR ./configs/edvr_L.yaml
# bash ./run.sh predict EDvR ./cofings/edvr_L.yaml
# configs should be ./configs/xxx.yaml
mode
=
$1
name
=
$2
configs
=
$3
save_inference_dir
=
"./data/inference_model"
use_gpu
=
True
fix_random_seed
=
False
log_interval
=
1
valid_interval
=
1
weights
=
"./weights/paddle_state_dict_L.npz"
export
CUDA_VISIBLE_DEVICES
=
4,5,6,7
#0,1,5,6 fast, 2,3,4,7 slow
export
FLAGS_fast_eager_deletion_mode
=
1
export
FLAGS_eager_delete_tensor_gb
=
0.0
export
FLAGS_fraction_of_gpu_memory_to_use
=
0.98
if
[
"
$mode
"
x
==
"predict"
x
]
;
then
echo
$mode
$name
$configs
$weights
if
[
"
$weights
"
x
!=
""
x
]
;
then
python predict.py
--model_name
=
$name
\
--config
=
$configs
\
--log_interval
=
$log_interval
\
--video_path
=
''
\
--use_gpu
=
$use_gpu
else
python predict.py
--model_name
=
$name
\
--config
=
$configs
\
--log_interval
=
$log_interval
\
--use_gpu
=
$use_gpu
\
--video_path
=
''
fi
fi
applications/EDVR/utils/__init__.py
0 → 100644
浏览文件 @
031e15f1
applications/EDVR/utils/config_utils.py
0 → 100644
浏览文件 @
031e15f1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
yaml
from
.utility
import
AttrDict
import
logging
logger
=
logging
.
getLogger
(
__name__
)
CONFIG_SECS
=
[
'train'
,
'valid'
,
'test'
,
'infer'
,
]
def
parse_config
(
cfg_file
):
"""Load a config file into AttrDict"""
import
yaml
with
open
(
cfg_file
,
'r'
)
as
fopen
:
yaml_config
=
AttrDict
(
yaml
.
load
(
fopen
,
Loader
=
yaml
.
Loader
))
create_attr_dict
(
yaml_config
)
return
yaml_config
def
create_attr_dict
(
yaml_config
):
from
ast
import
literal_eval
for
key
,
value
in
yaml_config
.
items
():
if
type
(
value
)
is
dict
:
yaml_config
[
key
]
=
value
=
AttrDict
(
value
)
if
isinstance
(
value
,
str
):
try
:
value
=
literal_eval
(
value
)
except
BaseException
:
pass
if
isinstance
(
value
,
AttrDict
):
create_attr_dict
(
yaml_config
[
key
])
else
:
yaml_config
[
key
]
=
value
return
def
merge_configs
(
cfg
,
sec
,
args_dict
):
assert
sec
in
CONFIG_SECS
,
"invalid config section {}"
.
format
(
sec
)
sec_dict
=
getattr
(
cfg
,
sec
.
upper
())
for
k
,
v
in
args_dict
.
items
():
if
v
is
None
:
continue
try
:
if
hasattr
(
sec_dict
,
k
):
setattr
(
sec_dict
,
k
,
v
)
except
:
pass
return
cfg
def
print_configs
(
cfg
,
mode
):
logger
.
info
(
"---------------- {:>5} Arguments ----------------"
.
format
(
mode
))
for
sec
,
sec_items
in
cfg
.
items
():
logger
.
info
(
"{}:"
.
format
(
sec
))
for
k
,
v
in
sec_items
.
items
():
logger
.
info
(
" {}:{}"
.
format
(
k
,
v
))
logger
.
info
(
"-------------------------------------------------"
)
applications/EDVR/utils/utility.py
0 → 100644
浏览文件 @
031e15f1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
os
import
sys
import
signal
import
logging
import
paddle
import
paddle.fluid
as
fluid
__all__
=
[
'AttrDict'
]
logger
=
logging
.
getLogger
(
__name__
)
def
_term
(
sig_num
,
addition
):
print
(
'current pid is %s, group id is %s'
%
(
os
.
getpid
(),
os
.
getpgrp
()))
os
.
killpg
(
os
.
getpgid
(
os
.
getpid
()),
signal
.
SIGKILL
)
signal
.
signal
(
signal
.
SIGTERM
,
_term
)
signal
.
signal
(
signal
.
SIGINT
,
_term
)
class
AttrDict
(
dict
):
def
__getattr__
(
self
,
key
):
return
self
[
key
]
def
__setattr__
(
self
,
key
,
value
):
if
key
in
self
.
__dict__
:
self
.
__dict__
[
key
]
=
value
else
:
self
[
key
]
=
value
def
check_cuda
(
use_cuda
,
err
=
\
"
\n
You can not set use_gpu = True in the model because you are using paddlepaddle-cpu.
\n
\
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_gpu = False to run models on CPU.
\n
"
):
try
:
if
use_cuda
==
True
and
fluid
.
is_compiled_with_cuda
()
==
False
:
print
(
err
)
sys
.
exit
(
1
)
except
Exception
as
e
:
pass
def
check_version
():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err
=
"PaddlePaddle version 1.6 or higher is required, "
\
"or a suitable develop version is satisfied as well.
\n
"
\
"Please make sure the version is good with your code."
\
try
:
fluid
.
require_version
(
'1.6.0'
)
except
Exception
as
e
:
logger
.
error
(
err
)
sys
.
exit
(
1
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录