Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
5973f5f3
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5973f5f3
编写于
8月 12, 2020
作者:
L
lijianshe02
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete inference irralated code
上级
8c46c443
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
6 addition
and
184 deletion
+6
-184
applications/EDVR/eval.py
applications/EDVR/eval.py
+0
-172
applications/EDVR/predict.py
applications/EDVR/predict.py
+6
-11
applications/EDVR/run.sh
applications/EDVR/run.sh
+0
-1
未找到文件。
applications/EDVR/eval.py
已删除
100644 → 0
浏览文件 @
8c46c443
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
os
import
sys
import
time
import
logging
import
argparse
import
ast
import
numpy
as
np
import
paddle.fluid
as
fluid
from
utils.config_utils
import
*
import
models
from
reader
import
get_reader
from
metrics
import
get_metrics
from
utils.utility
import
check_cuda
from
utils.utility
import
check_version
logging
.
root
.
handlers
=
[]
FORMAT
=
'[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
,
stream
=
sys
.
stdout
)
logger
=
logging
.
getLogger
(
__name__
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model_name'
,
type
=
str
,
default
=
'AttentionCluster'
,
help
=
'name of model to train.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
default
=
'configs/attention_cluster.txt'
,
help
=
'path to config file of model'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
None
,
help
=
'test batch size. None to use config file setting.'
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'default use gpu.'
)
parser
.
add_argument
(
'--weights'
,
type
=
str
,
default
=
None
,
help
=
'weight path, None to automatically download weights provided by Paddle.'
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
os
.
path
.
join
(
'data'
,
'evaluate_results'
),
help
=
'output dir path, default to use ./data/evaluate_results'
)
parser
.
add_argument
(
'--log_interval'
,
type
=
int
,
default
=
1
,
help
=
'mini-batch interval to log.'
)
args
=
parser
.
parse_args
()
return
args
def
test
(
args
):
# parse config
config
=
parse_config
(
args
.
config
)
test_config
=
merge_configs
(
config
,
'test'
,
vars
(
args
))
print_configs
(
test_config
,
"Test"
)
# build model
test_model
=
models
.
get_model
(
args
.
model_name
,
test_config
,
mode
=
'test'
)
test_model
.
build_input
(
use_dataloader
=
False
)
test_model
.
build_model
()
test_feeds
=
test_model
.
feeds
()
test_fetch_list
=
test_model
.
fetches
()
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
if
args
.
weights
:
assert
os
.
path
.
exists
(
args
.
weights
),
"Given weight dir {} not exist."
.
format
(
args
.
weights
)
weights
=
args
.
weights
or
test_model
.
get_weights
()
logger
.
info
(
'load test weights from {}'
.
format
(
weights
))
test_model
.
load_test_weights
(
exe
,
weights
,
fluid
.
default_main_program
(),
place
)
# get reader and metrics
test_reader
=
get_reader
(
args
.
model_name
.
upper
(),
'test'
,
test_config
)
test_metrics
=
get_metrics
(
args
.
model_name
.
upper
(),
'test'
,
test_config
)
test_feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
test_feeds
)
epoch_period
=
[]
for
test_iter
,
data
in
enumerate
(
test_reader
()):
cur_time
=
time
.
time
()
if
args
.
model_name
==
'ETS'
:
feat_data
=
[
items
[:
3
]
for
items
in
data
]
vinfo
=
[
items
[
3
:]
for
items
in
data
]
test_outs
=
exe
.
run
(
fetch_list
=
test_fetch_list
,
feed
=
test_feeder
.
feed
(
feat_data
),
return_numpy
=
False
)
test_outs
+=
[
vinfo
]
elif
args
.
model_name
==
'TALL'
:
feat_data
=
[
items
[:
2
]
for
items
in
data
]
vinfo
=
[
items
[
2
:]
for
items
in
data
]
test_outs
=
exe
.
run
(
fetch_list
=
test_fetch_list
,
feed
=
test_feeder
.
feed
(
feat_data
),
return_numpy
=
True
)
test_outs
+=
[
vinfo
]
elif
args
.
model_name
==
'EDVR'
:
#img_data = [item[0] for item in data]
#gt_data = [item[1] for item in data]
#gt_data = gt_data[0]
#gt_data = np.transpose(gt_data, (1,2,0))
#gt_data = gt_data[:, :, ::-1]
#print('input', img_data)
#print('gt', gt_data)
feat_data
=
[
items
[:
2
]
for
items
in
data
]
print
(
"feat_data[0] shape: "
,
feat_data
[
0
][
0
].
shape
)
exit
()
vinfo
=
[
items
[
2
:]
for
items
in
data
]
test_outs
=
exe
.
run
(
fetch_list
=
test_fetch_list
,
feed
=
test_feeder
.
feed
(
feat_data
),
return_numpy
=
True
)
#output = test_outs[1]
#print('output', output)
test_outs
+=
[
vinfo
]
else
:
test_outs
=
exe
.
run
(
fetch_list
=
test_fetch_list
,
feed
=
test_feeder
.
feed
(
data
))
period
=
time
.
time
()
-
cur_time
epoch_period
.
append
(
period
)
test_metrics
.
accumulate
(
test_outs
)
# metric here
if
args
.
log_interval
>
0
and
test_iter
%
args
.
log_interval
==
0
:
info_str
=
'[EVAL] Batch {}'
.
format
(
test_iter
)
test_metrics
.
calculate_and_log_out
(
test_outs
,
info_str
)
if
not
os
.
path
.
isdir
(
args
.
save_dir
):
os
.
makedirs
(
args
.
save_dir
)
test_metrics
.
finalize_and_log_out
(
"[EVAL] eval finished. "
,
args
.
save_dir
)
if
__name__
==
"__main__"
:
args
=
parse_args
()
# check whether the installed paddle is compiled with GPU
check_cuda
(
args
.
use_gpu
)
check_version
()
logger
.
info
(
args
)
test
(
args
)
applications/EDVR/predict.py
浏览文件 @
5973f5f3
...
...
@@ -56,12 +56,12 @@ def parse_args():
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'default use gpu.'
)
parser
.
add_argument
(
'--weights'
,
type
=
str
,
default
=
None
,
help
=
'weight path, None to automatically download weights provided by Paddle.'
)
#
parser.add_argument(
#
'--weights',
#
type=str,
#
default=None,
#
help='weight path, None to automatically download weights provided by Paddle.'
#
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
...
...
@@ -117,11 +117,6 @@ def infer(args):
config
=
parse_config
(
args
.
config
)
infer_config
=
merge_configs
(
config
,
'infer'
,
vars
(
args
))
print_configs
(
infer_config
,
"Infer"
)
#infer_model = models.get_model(args.model_name, infer_config, mode='infer')
#infer_model.build_input(use_dataloader=False)
#infer_model.build_model()
#infer_feeds = infer_model.feeds()
#infer_outputs = infer_model.outputs()
model_path
=
'/workspace/video_test/video/for_eval/data/inference_model'
model_filename
=
'EDVR_model.pdmodel'
...
...
applications/EDVR/run.sh
浏览文件 @
5973f5f3
...
...
@@ -34,7 +34,6 @@ if [ "$mode"x == "predict"x ]; then
python predict.py
--model_name
=
$name
\
--config
=
$configs
\
--log_interval
=
$log_interval
\
--weights
=
$weights
\
--video_path
=
''
\
--use_gpu
=
$use_gpu
else
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录