Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
e0b40b36
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e0b40b36
编写于
8月 12, 2020
作者:
L
lijianshe02
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete inference irralated code
上级
5973f5f3
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
3 addition
and
273 deletion
+3
-273
applications/EDVR/configs/edvr_L.yaml
applications/EDVR/configs/edvr_L.yaml
+1
-13
applications/EDVR/inference_model.py
applications/EDVR/inference_model.py
+1
-1
applications/EDVR/metrics/__init__.py
applications/EDVR/metrics/__init__.py
+0
-1
applications/EDVR/metrics/edvr_metrics/__init__.py
applications/EDVR/metrics/edvr_metrics/__init__.py
+0
-0
applications/EDVR/metrics/edvr_metrics/edvr_metrics.py
applications/EDVR/metrics/edvr_metrics/edvr_metrics.py
+0
-145
applications/EDVR/metrics/metrics_util.py
applications/EDVR/metrics/metrics_util.py
+0
-106
applications/EDVR/predict.py
applications/EDVR/predict.py
+1
-7
未找到文件。
applications/EDVR/configs/edvr_L.yaml
浏览文件 @
e0b40b36
...
...
@@ -11,18 +11,6 @@ MODEL:
HR_in
:
False
w_TSA
:
True
#False
TEST
:
scale
:
4
crop_size
:
256
interval_list
:
[
1
]
random_reverse
:
False
number_frames
:
5
batch_size
:
1
file_root
:
"
/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic"
gt_root
:
"
/workspace/video_test/video/data/dataset/edvr/REDS4/GT"
use_flip
:
False
use_rot
:
False
INFER
:
scale
:
4
crop_size
:
256
...
...
@@ -31,6 +19,6 @@ INFER:
number_frames
:
5
batch_size
:
1
file_root
:
"
/workspace/color/input_frames"
gt_root
:
"
/workspace/video_test/video/data/dataset/edvr/REDS4/GT"
#
gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT"
use_flip
:
False
use_rot
:
False
applications/EDVR/inference_model.py
浏览文件 @
e0b40b36
...
...
@@ -28,7 +28,7 @@ import paddle.fluid as fluid
from
utils.config_utils
import
*
import
models
from
reader
import
get_reader
from
metrics
import
get_metrics
#
from metrics import get_metrics
from
utils.utility
import
check_cuda
logging
.
root
.
handlers
=
[]
...
...
applications/EDVR/metrics/__init__.py
已删除
100644 → 0
浏览文件 @
5973f5f3
from
.metrics_util
import
get_metrics
applications/EDVR/metrics/edvr_metrics/__init__.py
已删除
100644 → 0
浏览文件 @
5973f5f3
applications/EDVR/metrics/edvr_metrics/edvr_metrics.py
已删除
100644 → 0
浏览文件 @
5973f5f3
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
import
numpy
as
np
import
datetime
import
logging
import
json
import
os
import
cv2
import
math
logger
=
logging
.
getLogger
(
__name__
)
class
MetricsCalculator
():
def
__init__
(
self
,
name
=
'EDVR'
,
mode
=
'train'
):
self
.
name
=
name
self
.
mode
=
mode
# 'train', 'valid', 'test', 'infer'
self
.
reset
()
self
.
total_frames
=
9002
#100
self
.
bolder_frames
=
2
def
reset
(
self
):
logger
.
info
(
'Resetting {} metrics...'
.
format
(
self
.
mode
))
if
(
self
.
mode
==
'train'
)
or
(
self
.
mode
==
'valid'
):
self
.
aggr_loss
=
0.0
elif
(
self
.
mode
==
'test'
)
or
(
self
.
mode
==
'infer'
):
self
.
result_dict
=
dict
()
def
calculate_and_logout
(
self
,
fetch_list
,
info
):
pass
def
accumulate
(
self
,
fetch_list
):
loss
=
fetch_list
[
0
]
pred
=
fetch_list
[
1
]
gt
=
fetch_list
[
2
]
videoinfo
=
fetch_list
[
-
1
]
print
(
'videoinfo: '
,
videoinfo
)
videonames
=
[
item
[
0
]
for
item
in
videoinfo
]
framenames
=
[
item
[
1
]
for
item
in
videoinfo
]
for
i
in
range
(
len
(
pred
)):
pred_i
=
pred
[
i
]
gt_i
=
gt
[
i
]
videoname_i
=
videonames
[
i
]
framename_i
=
framenames
[
i
]
if
videoname_i
not
in
self
.
result_dict
.
keys
():
self
.
result_dict
[
videoname_i
]
=
{}
if
framename_i
in
self
.
result_dict
[
videoname_i
].
keys
():
logger
.
info
(
"frame {} already processed in video {}, please check it"
.
format
(
framename_i
,
videoname_i
))
raise
is_bolder
=
(
int
(
framename_i
)
>
(
self
.
total_frames
-
self
.
bolder_frames
-
1
)
or
int
(
framename_i
)
<
self
.
bolder_frames
)
psnr_i
=
get_psnr
(
pred_i
,
gt_i
)
img_i
=
get_img
(
pred_i
)
self
.
result_dict
[
videoname_i
][
framename_i
]
=
[
is_bolder
,
psnr_i
]
is_save
=
True
if
is_save
and
(
i
==
len
(
pred
)
-
1
):
save_img
(
img_i
,
framename_i
)
logger
.
info
(
"video {}, frame {}, bolder {}, psnr = {}"
.
format
(
videoname_i
,
framename_i
,
is_bolder
,
psnr_i
))
def
finalize_metrics
(
self
,
savedir
):
avg_psnr
=
0.
avg_psnr_center
=
0.
avg_psnr_bolder
=
0.
center_num
=
0.
bolder_num
=
0.
for
videoname
in
self
.
result_dict
.
keys
():
videoresult
=
self
.
result_dict
[
videoname
]
framelist
=
list
(
videoresult
.
keys
())
video_psnr_center
=
0.
video_psnr_bolder
=
0.
video_center_num
=
0.
video_bolder_num
=
0.
for
frame
in
framelist
:
frameresult
=
videoresult
[
frame
]
is_bolder
=
frameresult
[
0
]
psnr
=
frameresult
[
1
]
if
is_bolder
:
video_bolder_num
+=
1
video_psnr_bolder
+=
psnr
else
:
video_center_num
+=
1
video_psnr_center
+=
psnr
video_num
=
video_bolder_num
+
video_center_num
video_psnr
=
video_psnr_center
+
video_psnr_bolder
avg_psnr_bolder
+=
video_psnr_bolder
avg_psnr_center
+=
video_psnr_center
bolder_num
+=
video_bolder_num
center_num
+=
video_center_num
logger
.
info
(
"video {}, total frame num/psnr {}/{}, center num/psnr {}/{}, bolder num/psnr {}/{}"
.
format
(
videoname
,
video_num
,
video_psnr
/
video_num
,
video_center_num
,
video_psnr_center
/
video_center_num
,
video_bolder_num
,
video_psnr_bolder
/
video_bolder_num
))
avg_psnr
=
avg_psnr_bolder
+
avg_psnr_center
total_num
=
bolder_num
+
center_num
avg_psnr
=
avg_psnr
/
total_num
avg_psnr_center
=
avg_psnr_center
/
center_num
avg_psnr_bolder
=
avg_psnr_bolder
/
bolder_num
logger
.
info
(
"Average psnr {}, center {}, bolder {}"
.
format
(
avg_psnr
,
avg_psnr_center
,
avg_psnr_bolder
))
def
get_psnr
(
pred
,
gt
):
# pred and gt have range [0, 1]
pred
=
pred
.
squeeze
().
astype
(
np
.
float64
)
pred
=
pred
*
255.
pred
=
pred
.
round
()
gt
=
gt
.
squeeze
().
astype
(
np
.
float64
)
gt
=
gt
*
255.
gt
=
gt
.
round
()
mse
=
np
.
mean
((
pred
-
gt
)
**
2
)
if
mse
==
0
:
return
float
(
'inf'
)
return
20
*
math
.
log10
(
255.0
/
math
.
sqrt
(
mse
))
def
get_img
(
pred
):
print
(
'pred shape'
,
pred
.
shape
)
pred
=
pred
.
squeeze
()
pred
=
np
.
clip
(
pred
,
a_min
=
0.
,
a_max
=
1.0
)
pred
=
pred
*
255
pred
=
pred
.
round
()
pred
=
pred
.
astype
(
'uint8'
)
pred
=
np
.
transpose
(
pred
,
(
1
,
2
,
0
))
# chw -> hwc
pred
=
pred
[:,
:,
::
-
1
]
# rgb -> bgr
return
pred
def
save_img
(
img
,
framename
):
dirname
=
'./demo/resultpng'
filename
=
os
.
path
.
join
(
dirname
,
framename
+
'.png'
)
cv2
.
imwrite
(
filename
,
img
)
applications/EDVR/metrics/metrics_util.py
已删除
100644 → 0
浏览文件 @
5973f5f3
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
unicode_literals
from
__future__
import
print_function
from
__future__
import
division
import
logging
import
numpy
as
np
import
json
from
metrics.edvr_metrics
import
edvr_metrics
as
edvr_metrics
logger
=
logging
.
getLogger
(
__name__
)
class
Metrics
(
object
):
def
__init__
(
self
,
name
,
mode
,
metrics_args
):
"""Not implemented"""
pass
def
calculate_and_log_out
(
self
,
fetch_list
,
info
=
''
):
"""Not implemented"""
pass
def
accumulate
(
self
,
fetch_list
,
info
=
''
):
"""Not implemented"""
pass
def
finalize_and_log_out
(
self
,
info
=
''
,
savedir
=
'./'
):
"""Not implemented"""
pass
def
reset
(
self
):
"""Not implemented"""
pass
class
EDVRMetrics
(
Metrics
):
def
__init__
(
self
,
name
,
mode
,
cfg
):
self
.
name
=
name
self
.
mode
=
mode
args
=
{}
args
[
'mode'
]
=
mode
args
[
'name'
]
=
name
self
.
calculator
=
edvr_metrics
.
MetricsCalculator
(
**
args
)
def
calculate_and_log_out
(
self
,
fetch_list
,
info
=
''
):
if
(
self
.
mode
==
'train'
)
or
(
self
.
mode
==
'valid'
):
loss
=
np
.
array
(
fetch_list
[
0
])
logger
.
info
(
info
+
'
\t
Loss = {}'
.
format
(
'%.04f'
%
np
.
mean
(
loss
)))
elif
self
.
mode
==
'test'
:
pass
def
accumulate
(
self
,
fetch_list
):
self
.
calculator
.
accumulate
(
fetch_list
)
def
finalize_and_log_out
(
self
,
info
=
''
,
savedir
=
'./'
):
self
.
calculator
.
finalize_metrics
(
savedir
)
def
reset
(
self
):
self
.
calculator
.
reset
()
class
MetricsZoo
(
object
):
def
__init__
(
self
):
self
.
metrics_zoo
=
{}
def
regist
(
self
,
name
,
metrics
):
assert
metrics
.
__base__
==
Metrics
,
"Unknow model type {}"
.
format
(
type
(
metrics
))
self
.
metrics_zoo
[
name
]
=
metrics
def
get
(
self
,
name
,
mode
,
cfg
):
for
k
,
v
in
self
.
metrics_zoo
.
items
():
if
k
==
name
:
return
v
(
name
,
mode
,
cfg
)
raise
MetricsNotFoundError
(
name
,
self
.
metrics_zoo
.
keys
())
# singleton metrics_zoo
metrics_zoo
=
MetricsZoo
()
def
regist_metrics
(
name
,
metrics
):
metrics_zoo
.
regist
(
name
,
metrics
)
def
get_metrics
(
name
,
mode
,
cfg
):
return
metrics_zoo
.
get
(
name
,
mode
,
cfg
)
# sort by alphabet
regist_metrics
(
"EDVR"
,
EDVRMetrics
)
applications/EDVR/predict.py
浏览文件 @
e0b40b36
...
...
@@ -29,7 +29,7 @@ import cv2
from
utils.config_utils
import
*
import
models
from
reader
import
get_reader
from
metrics
import
get_metrics
#
from metrics import get_metrics
from
utils.utility
import
check_cuda
from
utils.utility
import
check_version
...
...
@@ -56,12 +56,6 @@ def parse_args():
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'default use gpu.'
)
# parser.add_argument(
# '--weights',
# type=str,
# default=None,
# help='weight path, None to automatically download weights provided by Paddle.'
# )
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录