Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
542c8839
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
542c8839
编写于
8月 12, 2020
作者:
L
lijianshe02
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support variable length input
上级
cac2fbdf
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
4 addition
and
1050 deletion
+4
-1050
applications/EDVR/configs/edvr_L.yaml
applications/EDVR/configs/edvr_L.yaml
+1
-0
applications/EDVR/inference_model.py
applications/EDVR/inference_model.py
+0
-123
applications/EDVR/models/__init__.py
applications/EDVR/models/__init__.py
+0
-4
applications/EDVR/models/edvr/__init__.py
applications/EDVR/models/edvr/__init__.py
+0
-1
applications/EDVR/models/edvr/edvr.py
applications/EDVR/models/edvr/edvr.py
+0
-265
applications/EDVR/models/edvr/edvr_model.py
applications/EDVR/models/edvr/edvr_model.py
+0
-417
applications/EDVR/models/model.py
applications/EDVR/models/model.py
+0
-191
applications/EDVR/models/utils.py
applications/EDVR/models/utils.py
+0
-47
applications/EDVR/predict.py
applications/EDVR/predict.py
+2
-2
applications/EDVR/reader/edvr_reader.py
applications/EDVR/reader/edvr_reader.py
+1
-0
未找到文件。
applications/EDVR/configs/edvr_L.yaml
浏览文件 @
542c8839
...
...
@@ -19,6 +19,7 @@ INFER:
number_frames
:
5
batch_size
:
1
file_root
:
"
/workspace/color/input_frames"
#file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic"
#gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT"
use_flip
:
False
use_rot
:
False
applications/EDVR/inference_model.py
已删除
100644 → 0
浏览文件 @
cac2fbdf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
os
import
sys
import
time
import
logging
import
argparse
import
ast
import
numpy
as
np
try
:
import
cPickle
as
pickle
except
:
import
pickle
import
paddle.fluid
as
fluid
from
utils.config_utils
import
*
import
models
from
reader
import
get_reader
#from metrics import get_metrics
from
utils.utility
import
check_cuda
logging
.
root
.
handlers
=
[]
FORMAT
=
'[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
format
=
FORMAT
,
stream
=
sys
.
stdout
)
logger
=
logging
.
getLogger
(
__name__
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model_name'
,
type
=
str
,
default
=
'AttentionCluster'
,
help
=
'name of model to train.'
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
default
=
'configs/attention_cluster.txt'
,
help
=
'path to config file of model'
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'default use gpu.'
)
parser
.
add_argument
(
'--weights'
,
type
=
str
,
default
=
None
,
help
=
'weight path, None to automatically download weights provided by Paddle.'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
1
,
help
=
'sample number in a batch for inference.'
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
'./'
,
help
=
'directory to store model and params file'
)
args
=
parser
.
parse_args
()
return
args
def
save_inference_model
(
args
):
# parse config
config
=
parse_config
(
args
.
config
)
infer_config
=
merge_configs
(
config
,
'infer'
,
vars
(
args
))
print_configs
(
infer_config
,
"Infer"
)
infer_model
=
models
.
get_model
(
args
.
model_name
,
infer_config
,
mode
=
'infer'
)
infer_model
.
build_input
(
use_dataloader
=
False
)
infer_model
.
build_model
()
infer_feeds
=
infer_model
.
feeds
()
infer_outputs
=
infer_model
.
outputs
()
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
if
args
.
weights
:
assert
os
.
path
.
exists
(
args
.
weights
),
"Given weight dir {} not exist."
.
format
(
args
.
weights
)
# if no weight files specified, download weights from paddle
weights
=
args
.
weights
or
infer_model
.
get_weights
()
infer_model
.
load_test_weights
(
exe
,
weights
,
fluid
.
default_main_program
(),
place
)
if
not
os
.
path
.
isdir
(
args
.
save_dir
):
os
.
makedirs
(
args
.
save_dir
)
# saving inference model
fluid
.
io
.
save_inference_model
(
args
.
save_dir
,
feeded_var_names
=
[
item
.
name
for
item
in
infer_feeds
],
target_vars
=
infer_outputs
,
executor
=
exe
,
main_program
=
fluid
.
default_main_program
(),
model_filename
=
args
.
model_name
+
"_model.pdmodel"
,
params_filename
=
args
.
model_name
+
"_params.pdparams"
)
print
(
"save inference model at %s"
%
(
args
.
save_dir
))
if
__name__
==
"__main__"
:
args
=
parse_args
()
# check whether the installed paddle is compiled with GPU
check_cuda
(
args
.
use_gpu
)
logger
.
info
(
args
)
save_inference_model
(
args
)
applications/EDVR/models/__init__.py
已删除
100644 → 0
浏览文件 @
cac2fbdf
from
.model
import
regist_model
,
get_model
from
.edvr
import
EDVR
regist_model
(
"EDVR"
,
EDVR
)
applications/EDVR/models/edvr/__init__.py
已删除
100644 → 0
浏览文件 @
cac2fbdf
from
.edvr
import
*
applications/EDVR/models/edvr/edvr.py
已删除
100644 → 0
浏览文件 @
cac2fbdf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
math
import
paddle.fluid
as
fluid
from
paddle.fluid
import
ParamAttr
from
..model
import
ModelBase
from
.edvr_model
import
EDVRModel
import
logging
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"EDVR"
]
class
EDVR
(
ModelBase
):
def
__init__
(
self
,
name
,
cfg
,
mode
=
'train'
):
super
(
EDVR
,
self
).
__init__
(
name
,
cfg
,
mode
=
mode
)
self
.
get_config
()
def
get_config
(
self
):
self
.
num_filters
=
self
.
get_config_from_sec
(
'model'
,
'num_filters'
)
self
.
num_frames
=
self
.
get_config_from_sec
(
'model'
,
'num_frames'
)
self
.
dcn_groups
=
self
.
get_config_from_sec
(
'model'
,
'deform_conv_groups'
)
self
.
front_RBs
=
self
.
get_config_from_sec
(
'model'
,
'front_RBs'
)
self
.
back_RBs
=
self
.
get_config_from_sec
(
'model'
,
'back_RBs'
)
self
.
center
=
self
.
get_config_from_sec
(
'model'
,
'center'
,
2
)
self
.
predeblur
=
self
.
get_config_from_sec
(
'model'
,
'predeblur'
,
False
)
self
.
HR_in
=
self
.
get_config_from_sec
(
'model'
,
'HR_in'
,
False
)
self
.
w_TSA
=
self
.
get_config_from_sec
(
'model'
,
'w_TSA'
,
True
)
self
.
crop_size
=
self
.
get_config_from_sec
(
self
.
mode
,
'crop_size'
)
self
.
scale
=
self
.
get_config_from_sec
(
self
.
mode
,
'scale'
,
1
)
self
.
num_gpus
=
self
.
get_config_from_sec
(
self
.
mode
,
'num_gpus'
,
8
)
self
.
batch_size
=
self
.
get_config_from_sec
(
self
.
mode
,
'batch_size'
,
256
)
# get optimizer related parameters
self
.
base_learning_rate
=
self
.
get_config_from_sec
(
'train'
,
'base_learning_rate'
)
self
.
l2_weight_decay
=
self
.
get_config_from_sec
(
'train'
,
'l2_weight_decay'
)
self
.
T_periods
=
self
.
get_config_from_sec
(
'train'
,
'T_periods'
)
self
.
restarts
=
self
.
get_config_from_sec
(
'train'
,
'restarts'
)
self
.
weights
=
self
.
get_config_from_sec
(
'train'
,
'weights'
)
self
.
eta_min
=
self
.
get_config_from_sec
(
'train'
,
'eta_min'
)
self
.
TSA_only
=
self
.
get_config_from_sec
(
'train'
,
'TSA_only'
,
False
)
def
build_input
(
self
,
use_dataloader
=
True
):
if
self
.
mode
!=
'test'
:
gt_shape
=
[
None
,
3
,
self
.
crop_size
,
self
.
crop_size
]
else
:
gt_shape
=
[
None
,
3
,
720
,
1280
]
if
self
.
HR_in
:
img_shape
=
[
-
1
,
self
.
num_frames
,
3
,
self
.
crop_size
,
self
.
crop_size
]
else
:
if
(
self
.
mode
!=
'test'
)
and
(
self
.
mode
!=
'infer'
)
:
img_shape
=
[
None
,
self
.
num_frames
,
3
,
\
int
(
self
.
crop_size
/
self
.
scale
),
int
(
self
.
crop_size
/
self
.
scale
)]
else
:
img_shape
=
[
None
,
self
.
num_frames
,
3
,
360
,
472
]
#180, 320]
self
.
use_dataloader
=
use_dataloader
image
=
fluid
.
data
(
name
=
'LQ_IMGs'
,
shape
=
img_shape
,
dtype
=
'float32'
)
if
self
.
mode
!=
'infer'
:
label
=
fluid
.
data
(
name
=
'GT_IMG'
,
shape
=
gt_shape
,
dtype
=
'float32'
)
else
:
label
=
None
if
use_dataloader
:
assert
self
.
mode
!=
'infer'
,
\
'dataloader is not recommendated when infer, please set use_dataloader to be false.'
self
.
dataloader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
[
image
,
label
],
capacity
=
4
,
iterable
=
True
)
self
.
feature_input
=
[
image
]
self
.
label_input
=
label
def
create_model_args
(
self
):
cfg
=
{}
cfg
[
'nf'
]
=
self
.
num_filters
cfg
[
'nframes'
]
=
self
.
num_frames
cfg
[
'groups'
]
=
self
.
dcn_groups
cfg
[
'front_RBs'
]
=
self
.
front_RBs
cfg
[
'back_RBs'
]
=
self
.
back_RBs
cfg
[
'center'
]
=
self
.
center
cfg
[
'predeblur'
]
=
self
.
predeblur
cfg
[
'HR_in'
]
=
self
.
HR_in
cfg
[
'w_TSA'
]
=
self
.
w_TSA
cfg
[
'mode'
]
=
self
.
mode
cfg
[
'TSA_only'
]
=
self
.
TSA_only
return
cfg
def
build_model
(
self
):
cfg
=
self
.
create_model_args
()
videomodel
=
EDVRModel
(
**
cfg
)
out
=
videomodel
.
net
(
self
.
feature_input
[
0
])
self
.
network_outputs
=
[
out
]
def
optimizer
(
self
):
assert
self
.
mode
==
'train'
,
"optimizer only can be get in train mode"
learning_rate
=
get_lr
(
base_lr
=
self
.
base_learning_rate
,
T_periods
=
self
.
T_periods
,
restarts
=
self
.
restarts
,
weights
=
self
.
weights
,
eta_min
=
self
.
eta_min
)
l2_weight_decay
=
self
.
l2_weight_decay
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
learning_rate
,
beta1
=
0.9
,
beta2
=
0.99
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
l2_weight_decay
))
return
optimizer
def
loss
(
self
):
assert
self
.
mode
!=
'infer'
,
"invalid loss calculationg in infer mode"
pred
=
self
.
network_outputs
[
0
]
label
=
self
.
label_input
epsilon
=
1e-6
diff
=
pred
-
label
diff
=
diff
*
diff
+
epsilon
diff
=
fluid
.
layers
.
sqrt
(
diff
)
self
.
loss_
=
fluid
.
layers
.
reduce_sum
(
diff
)
return
self
.
loss_
def
outputs
(
self
):
return
self
.
network_outputs
def
feeds
(
self
):
return
self
.
feature_input
if
self
.
mode
==
'infer'
else
self
.
feature_input
+
[
self
.
label_input
]
def
fetches
(
self
):
if
self
.
mode
==
'train'
or
self
.
mode
==
'valid'
:
losses
=
self
.
loss
()
fetch_list
=
[
losses
,
self
.
network_outputs
[
0
],
self
.
label_input
]
elif
self
.
mode
==
'test'
:
losses
=
self
.
loss
()
fetch_list
=
[
losses
,
self
.
network_outputs
[
0
],
self
.
label_input
]
elif
self
.
mode
==
'infer'
:
fetch_list
=
self
.
network_outputs
else
:
raise
NotImplementedError
(
'mode {} not implemented'
.
format
(
self
.
mode
))
return
fetch_list
def
pretrain_info
(
self
):
return
(
None
,
None
)
def
weights_info
(
self
):
return
(
None
,
None
)
def
load_pretrain_params0
(
self
,
exe
,
pretrain
,
prog
,
place
):
"""load pretrain form .npz which is created by torch"""
def
is_parameter
(
var
):
return
isinstance
(
var
,
fluid
.
framework
.
Parameter
)
params_list
=
list
(
filter
(
is_parameter
,
prog
.
list_vars
()))
import
numpy
as
np
state_dict
=
np
.
load
(
pretrain
)
for
p
in
params_list
:
if
p
.
name
in
state_dict
.
keys
():
print
(
'########### load param {} from file'
.
format
(
p
.
name
))
else
:
print
(
'----------- param {} not in file'
.
format
(
p
.
name
))
fluid
.
set_program_state
(
prog
,
state_dict
)
print
(
'load pretrain from '
,
pretrain
)
def
load_test_weights
(
self
,
exe
,
weights
,
prog
,
place
):
"""load weights from .npz which is created by torch"""
def
is_parameter
(
var
):
return
isinstance
(
var
,
fluid
.
framework
.
Parameter
)
params_list
=
list
(
filter
(
is_parameter
,
prog
.
list_vars
()))
import
numpy
as
np
state_dict
=
np
.
load
(
weights
)
for
p
in
params_list
:
if
p
.
name
in
state_dict
.
keys
():
print
(
'########### load param {} from file'
.
format
(
p
.
name
))
else
:
print
(
'----------- param {} not in file'
.
format
(
p
.
name
))
fluid
.
set_program_state
(
prog
,
state_dict
)
print
(
'load weights from '
,
weights
)
# This is for learning rate cosine annealing restart
Dtype
=
'float32'
def
decay_step_counter
(
begin
=
0
):
# the first global step is zero in learning rate decay
global_step
=
fluid
.
layers
.
autoincreased_step_counter
(
counter_name
=
'@LR_DECAY_COUNTER@'
,
begin
=
begin
,
step
=
1
)
return
global_step
def
get_lr
(
base_lr
=
0.001
,
T_periods
=
[
250000
,
250000
,
250000
,
250000
],
restarts
=
[
250000
,
500000
,
750000
],
weights
=
[
1
,
1
,
1
],
eta_min
=
0
):
with
fluid
.
default_main_program
().
_lr_schedule_guard
():
global_step
=
decay_step_counter
()
lr
=
fluid
.
layers
.
create_global_var
(
shape
=
[
1
],
value
=
base_lr
,
dtype
=
Dtype
,
persistable
=
True
,
name
=
"learning_rate"
)
num_segs
=
len
(
restarts
)
restart_point
=
0
with
fluid
.
layers
.
Switch
()
as
switch
:
with
switch
.
case
(
global_step
==
0
):
pass
for
i
in
range
(
num_segs
):
T_max
=
T_periods
[
i
]
weight
=
weights
[
i
]
with
switch
.
case
(
global_step
<
restarts
[
i
]):
with
fluid
.
layers
.
Switch
()
as
switch_second
:
value_2Tmax
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int64'
,
value
=
2
*
T_max
)
step_checker
=
global_step
-
restart_point
-
1
-
T_max
with
switch_second
.
case
(
fluid
.
layers
.
elementwise_mod
(
step_checker
,
value_2Tmax
)
==
0
):
var_value
=
lr
+
(
base_lr
-
eta_min
)
*
(
1
-
math
.
cos
(
math
.
pi
/
float
(
T_max
)))
/
2
fluid
.
layers
.
assign
(
var_value
,
lr
)
with
switch_second
.
default
():
double_step
=
fluid
.
layers
.
cast
(
global_step
,
dtype
=
'float64'
)
-
float
(
restart_point
)
double_scale
=
(
1
+
fluid
.
layers
.
cos
(
math
.
pi
*
double_step
/
float
(
T_max
)))
/
\
(
1
+
fluid
.
layers
.
cos
(
math
.
pi
*
(
double_step
-
1
)
/
float
(
T_max
)))
float_scale
=
fluid
.
layers
.
cast
(
double_scale
,
dtype
=
Dtype
)
var_value
=
float_scale
*
(
lr
-
eta_min
)
+
eta_min
fluid
.
layers
.
assign
(
var_value
,
lr
)
with
switch
.
case
(
global_step
==
restarts
[
i
]):
var_value
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
Dtype
,
value
=
float
(
base_lr
*
weight
))
fluid
.
layers
.
assign
(
var_value
,
lr
)
restart_point
=
restarts
[
i
]
T_max
=
T_periods
[
num_segs
]
with
switch
.
default
():
with
fluid
.
layers
.
Switch
()
as
switch_second
:
value_2Tmax
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int64'
,
value
=
2
*
T_max
)
step_checker
=
global_step
-
restart_point
-
1
-
T_max
with
switch_second
.
case
(
fluid
.
layers
.
elementwise_mod
(
step_checker
,
value_2Tmax
)
==
0
):
var_value
=
lr
+
(
base_lr
-
eta_min
)
*
(
1
-
math
.
cos
(
math
.
pi
/
float
(
T_max
)))
/
2
fluid
.
layers
.
assign
(
var_value
,
lr
)
with
switch_second
.
default
():
double_step
=
fluid
.
layers
.
cast
(
global_step
,
dtype
=
'float64'
)
-
float
(
restart_point
)
double_scale
=
(
1
+
fluid
.
layers
.
cos
(
math
.
pi
*
double_step
/
float
(
T_max
)))
/
\
(
1
+
fluid
.
layers
.
cos
(
math
.
pi
*
(
double_step
-
1
)
/
float
(
T_max
)))
float_scale
=
fluid
.
layers
.
cast
(
double_scale
,
dtype
=
Dtype
)
var_value
=
float_scale
*
(
lr
-
eta_min
)
+
eta_min
fluid
.
layers
.
assign
(
var_value
,
lr
)
return
lr
applications/EDVR/models/edvr/edvr_model.py
已删除
100644 → 0
浏览文件 @
cac2fbdf
此差异已折叠。
点击以展开。
applications/EDVR/models/model.py
已删除
100644 → 0
浏览文件 @
cac2fbdf
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
os
import
wget
import
logging
try
:
from
configparser
import
ConfigParser
except
:
from
ConfigParser
import
ConfigParser
import
paddle.fluid
as
fluid
from
.utils
import
download
,
AttrDict
WEIGHT_DIR
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
'.paddle'
,
'weights'
)
logger
=
logging
.
getLogger
(
__name__
)
def
is_parameter
(
var
):
return
isinstance
(
var
,
fluid
.
framework
.
Parameter
)
class
NotImplementError
(
Exception
):
"Error: model function not implement"
def
__init__
(
self
,
model
,
function
):
super
(
NotImplementError
,
self
).
__init__
()
self
.
model
=
model
.
__class__
.
__name__
self
.
function
=
function
.
__name__
def
__str__
(
self
):
return
"Function {}() is not implemented in model {}"
.
format
(
self
.
function
,
self
.
model
)
class
ModelNotFoundError
(
Exception
):
"Error: model not found"
def
__init__
(
self
,
model_name
,
avail_models
):
super
(
ModelNotFoundError
,
self
).
__init__
()
self
.
model_name
=
model_name
self
.
avail_models
=
avail_models
def
__str__
(
self
):
msg
=
"Model {} Not Found.
\n
Availiable models:
\n
"
.
format
(
self
.
model_name
)
for
model
in
self
.
avail_models
:
msg
+=
" {}
\n
"
.
format
(
model
)
return
msg
class
ModelBase
(
object
):
def
__init__
(
self
,
name
,
cfg
,
mode
=
'train'
):
assert
mode
in
[
'train'
,
'valid'
,
'test'
,
'infer'
],
\
"Unknown mode type {}"
.
format
(
mode
)
self
.
name
=
name
self
.
is_training
=
(
mode
==
'train'
)
self
.
mode
=
mode
self
.
cfg
=
cfg
self
.
dataloader
=
None
def
build_model
(
self
):
"build model struct"
raise
NotImplementError
(
self
,
self
.
build_model
)
def
build_input
(
self
,
use_dataloader
):
"build input Variable"
raise
NotImplementError
(
self
,
self
.
build_input
)
def
optimizer
(
self
):
"get model optimizer"
raise
NotImplementError
(
self
,
self
.
optimizer
)
def
outputs
():
"get output variable"
raise
notimplementerror
(
self
,
self
.
outputs
)
def
loss
(
self
):
"get loss variable"
raise
notimplementerror
(
self
,
self
.
loss
)
def
feeds
(
self
):
"get feed inputs list"
raise
NotImplementError
(
self
,
self
.
feeds
)
def
fetches
(
self
):
"get fetch list of model"
raise
NotImplementError
(
self
,
self
.
fetches
)
def
weights_info
(
self
):
"get model weight default path and download url"
raise
NotImplementError
(
self
,
self
.
weights_info
)
def
get_weights
(
self
):
"get model weight file path, download weight from Paddle if not exist"
path
,
url
=
self
.
weights_info
()
path
=
os
.
path
.
join
(
WEIGHT_DIR
,
path
)
if
not
os
.
path
.
isdir
(
WEIGHT_DIR
):
logger
.
info
(
'{} not exists, will be created automatically.'
.
format
(
WEIGHT_DIR
))
os
.
makedirs
(
WEIGHT_DIR
)
if
os
.
path
.
exists
(
path
):
return
path
logger
.
info
(
"Download weights of {} from {}"
.
format
(
self
.
name
,
url
))
wget
.
download
(
url
,
path
)
return
path
def
dataloader
(
self
):
return
self
.
dataloader
def
epoch_num
(
self
):
"get train epoch num"
return
self
.
cfg
.
TRAIN
.
epoch
def
pretrain_info
(
self
):
"get pretrain base model directory"
return
(
None
,
None
)
def
get_pretrain_weights
(
self
):
"get model weight file path, download weight from Paddle if not exist"
path
,
url
=
self
.
pretrain_info
()
if
not
path
:
return
None
path
=
os
.
path
.
join
(
WEIGHT_DIR
,
path
)
if
not
os
.
path
.
isdir
(
WEIGHT_DIR
):
logger
.
info
(
'{} not exists, will be created automatically.'
.
format
(
WEIGHT_DIR
))
os
.
makedirs
(
WEIGHT_DIR
)
if
os
.
path
.
exists
(
path
):
return
path
logger
.
info
(
"Download pretrain weights of {} from {}"
.
format
(
self
.
name
,
url
))
download
(
url
,
path
)
return
path
def
load_pretrain_params
(
self
,
exe
,
pretrain
,
prog
,
place
):
logger
.
info
(
"Load pretrain weights from {}"
.
format
(
pretrain
))
state_dict
=
fluid
.
load_program_state
(
pretrain
)
fluid
.
set_program_state
(
prog
,
state_dict
)
def
load_test_weights
(
self
,
exe
,
weights
,
prog
,
place
):
params_list
=
list
(
filter
(
is_parameter
,
prog
.
list_vars
()))
fluid
.
load
(
prog
,
weights
,
executor
=
exe
,
var_list
=
params_list
)
def
get_config_from_sec
(
self
,
sec
,
item
,
default
=
None
):
if
sec
.
upper
()
not
in
self
.
cfg
:
return
default
return
self
.
cfg
[
sec
.
upper
()].
get
(
item
,
default
)
class
ModelZoo
(
object
):
def
__init__
(
self
):
self
.
model_zoo
=
{}
def
regist
(
self
,
name
,
model
):
assert
model
.
__base__
==
ModelBase
,
"Unknow model type {}"
.
format
(
type
(
model
))
self
.
model_zoo
[
name
]
=
model
def
get
(
self
,
name
,
cfg
,
mode
=
'train'
):
for
k
,
v
in
self
.
model_zoo
.
items
():
if
k
.
upper
()
==
name
.
upper
():
return
v
(
name
,
cfg
,
mode
)
raise
ModelNotFoundError
(
name
,
self
.
model_zoo
.
keys
())
# singleton model_zoo
model_zoo
=
ModelZoo
()
def
regist_model
(
name
,
model
):
model_zoo
.
regist
(
name
,
model
)
def
get_model
(
name
,
cfg
,
mode
=
'train'
):
return
model_zoo
.
get
(
name
,
cfg
,
mode
)
applications/EDVR/models/utils.py
已删除
100644 → 0
浏览文件 @
cac2fbdf
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
os
import
wget
import
tarfile
__all__
=
[
'decompress'
,
'download'
,
'AttrDict'
]
def
decompress
(
path
):
t
=
tarfile
.
open
(
path
)
t
.
extractall
(
path
=
os
.
path
.
split
(
path
)[
0
])
t
.
close
()
os
.
remove
(
path
)
def
download
(
url
,
path
):
weight_dir
=
os
.
path
.
split
(
path
)[
0
]
if
not
os
.
path
.
exists
(
weight_dir
):
os
.
makedirs
(
weight_dir
)
path
=
path
+
".tar.gz"
wget
.
download
(
url
,
path
)
decompress
(
path
)
class
AttrDict
(
dict
):
def
__getattr__
(
self
,
key
):
return
self
[
key
]
def
__setattr__
(
self
,
key
,
value
):
if
key
in
self
.
__dict__
:
self
.
__dict__
[
key
]
=
value
else
:
self
[
key
]
=
value
applications/EDVR/predict.py
浏览文件 @
542c8839
...
...
@@ -27,7 +27,7 @@ import paddle.fluid as fluid
import
cv2
from
utils.config_utils
import
*
import
models
#
import models
from
reader
import
get_reader
#from metrics import get_metrics
from
utils.utility
import
check_cuda
...
...
@@ -112,7 +112,7 @@ def infer(args):
infer_config
=
merge_configs
(
config
,
'infer'
,
vars
(
args
))
print_configs
(
infer_config
,
"Infer"
)
model_path
=
'/workspace/
video_test/video/for_eval
/data/inference_model'
model_path
=
'/workspace/
PaddleGAN/applications/EDVR
/data/inference_model'
model_filename
=
'EDVR_model.pdmodel'
params_filename
=
'EDVR_params.pdparams'
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
...
...
applications/EDVR/reader/edvr_reader.py
浏览文件 @
542c8839
...
...
@@ -280,6 +280,7 @@ def read_img(path, size=None, is_gt=False):
#if not is_gt:
# #print(path)
# img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
#print("path: ", path)
img
=
img
.
astype
(
np
.
float32
)
/
255.
if
img
.
ndim
==
2
:
img
=
np
.
expand_dims
(
img
,
axis
=
2
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录