PaddlePaddle / PaddleGAN
Commit 4a3ba224 (unverified)

Merge pull request #12 from LielinJiang/add-deep-remaster

Add DeepRemaster

Authored by LielinJiang on Aug 31, 2020; committed via GitHub on Aug 31, 2020.
Parents: 5b31853d, 77090941
8 changed files with 548 additions and 16 deletions (+548 −16)
applications/DAIN/predict.py               +1    −1
applications/DeOldify/predict.py           +1    −1
applications/DeepRemaster/predict.py       +209  −0
applications/DeepRemaster/remasternet.py   +187  −0
applications/DeepRemaster/utils.py         +35   −0
applications/EDVR/data.py                  +90   −0
applications/run.sh                        +1    −1
applications/tools/video-enhance.py        +24   −13
applications/DAIN/predict.py

```diff
@@ -252,7 +252,7 @@ class VideoFrameInterp(object):
         for item, time_offset in zip(y_, time_offsets):
             out_dir = os.path.join(frame_path_interpolated, vidname,
-                                   "{:0>4d}_{:0>4d}.png".format(i, count))
+                                   "{:0>6d}_{:0>4d}.png".format(i, count))
             count = count + 1
             imsave(out_dir, np.round(item).astype(np.uint8))
```
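A note on this change: the commit itself does not state the motivation, but since interpolated frames are written to disk and later collected in lexicographic filename order, the zero-padding width of the source-frame index bounds how many frames sort correctly; widening it from 4 to 6 digits plausibly fixes mis-ordering on clips with 10,000+ frames. A self-contained sketch of the failure mode (illustrative values, not from the repo):

```python
# 4-digit padding breaks lexicographic order once a frame index needs 5 digits;
# 6-digit padding keeps string order equal to numeric order.
frames = [9999, 10000, 2]
print(sorted("{:0>4d}_{:0>4d}.png".format(i, 0) for i in frames))
# ['0002_0000.png', '10000_0000.png', '9999_0000.png']  <- 10000 sorts before 9999
print(sorted("{:0>6d}_{:0>4d}.png".format(i, 0) for i in frames))
# ['000002_0000.png', '009999_0000.png', '010000_0000.png']
```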
applications/DeOldify/predict.py

```diff
@@ -74,7 +74,7 @@ class DeOldifyPredictor():
         img += img_mean
         img = img.transpose((1, 2, 0))
-        return (img * 255).astype('uint8')
+        return (img * 255).clip(0, 255).astype('uint8')

     def post_process(self, raw_color, orig):
         color_np = np.asarray(raw_color)
```
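Why the added `.clip(0, 255)`: after de-normalization the float image can overshoot the valid range, and casting an out-of-range float to `uint8` wraps modulo 256 instead of saturating, which shows up as speckle artifacts. A minimal demonstration:

```python
import numpy as np

img = np.array([1.1, 0.5])  # model output slightly overshooting [0, 1]
print((img * 255).astype('uint8'))               # [ 24 127] -- 280.5 wraps modulo 256 on typical builds
print((img * 255).clip(0, 255).astype('uint8'))  # [255 127] -- saturates as intended
```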
applications/DeepRemaster/predict.py (new file, mode 100644)

```python
import os
import sys

cur_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(cur_path)

import paddle
import paddle.nn as nn
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
import subprocess
import utils
from remasternet import NetworkR, NetworkC
from paddle.incubate.hapi.download import get_path_from_url

DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'

parser = argparse.ArgumentParser(description='Remastering')
parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--reference_dir', type=str, default=None,
                    help='Path to the reference image directory')
parser.add_argument('--colorization', action='store_true', default=False,
                    help='Remaster without colorization')
parser.add_argument('--mindim', type=int, default='360',
                    help='Length of minimum image edges')


class DeepReasterPredictor:
    def __init__(self, input, output, weight_path=None, colorization=False,
                 reference_dir=None, mindim=360):
        self.input = input
        self.output = os.path.join(output, 'DeepRemaster')
        self.colorization = colorization
        self.reference_dir = reference_dir
        self.mindim = mindim

        if weight_path is None:
            weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path)

        state_dict, _ = paddle.load(weight_path)

        self.modelR = NetworkR()
        self.modelR.load_dict(state_dict['modelR'])
        self.modelR.eval()
        if colorization:
            self.modelC = NetworkC()
            self.modelC.load_dict(state_dict['modelC'])
            self.modelC.eval()

    def run(self):
        outputdir = self.output
        outputdir_in = os.path.join(outputdir, 'input/')
        os.makedirs(outputdir_in, exist_ok=True)
        outputdir_out = os.path.join(outputdir, 'output/')
        os.makedirs(outputdir_out, exist_ok=True)

        # Prepare reference images
        if self.colorization:
            if self.reference_dir is not None:
                import glob
                ext_list = ['png', 'jpg', 'bmp']
                reference_files = []
                for ext in ext_list:
                    reference_files += glob.glob(
                        self.reference_dir + '/*.' + ext, recursive=True)
                aspect_mean = 0
                minedge_dim = 256
                refs = []
                for v in reference_files:
                    refimg = Image.open(v).convert('RGB')
                    w, h = refimg.size
                    aspect_mean += w / h
                    refs.append(refimg)
                aspect_mean /= len(reference_files)
                target_w = int(256 * aspect_mean) if aspect_mean > 1 else 256
                target_h = 256 if aspect_mean >= 1 else int(256 / aspect_mean)

                refimgs = []
                for i, v in enumerate(refs):
                    refimg = utils.addMergin(v, target_w=target_w, target_h=target_h)
                    refimg = np.array(refimg).astype('float32').transpose(2, 0, 1) / 255.0
                    refimgs.append(refimg)
                refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
                refimgs = paddle.unsqueeze(refimgs, 0)

        # Load video
        cap = cv2.VideoCapture(self.input)
        nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        minwh = min(v_w, v_h)
        scale = 1
        if minwh != self.mindim:
            scale = self.mindim / minwh
        t_w = round(v_w * scale / 16.) * 16
        t_h = round(v_h * scale / 16.) * 16
        fps = cap.get(cv2.CAP_PROP_FPS)
        pbar = tqdm(total=nframes)
        block = 5

        # Process
        with paddle.no_grad():
            it = 0
            while True:
                frame_pos = it * block
                if frame_pos >= nframes:
                    break
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
                if block >= nframes - frame_pos:
                    proc_g = nframes - frame_pos
                else:
                    proc_g = block

                input = None
                gtC = None
                for i in range(proc_g):
                    index = frame_pos + i
                    _, frame = cap.read()
                    frame = cv2.resize(frame, (t_w, t_h))
                    nchannels = frame.shape[2]
                    if nchannels == 1 or self.colorization:
                        frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
                        cv2.imwrite(outputdir_in + '%07d.png' % index, frame_l)
                        frame_l = paddle.to_tensor(frame_l.astype('float32'))
                        frame_l = paddle.reshape(
                            frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
                        frame_l = paddle.transpose(frame_l, [2, 0, 1])
                        frame_l /= 255.
                        frame_l = paddle.reshape(frame_l, [
                            1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
                    elif nchannels == 3:
                        cv2.imwrite(outputdir_in + '%07d.png' % index, frame)
                        frame = frame[:, :, ::-1]  ## BGR -> RGB
                        frame_l, frame_ab = utils.convertRGB2LABTensor(frame)
                        frame_l = frame_l.transpose([2, 0, 1])
                        frame_ab = frame_ab.transpose([2, 0, 1])
                        frame_l = frame_l.reshape([
                            1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
                        frame_ab = frame_ab.reshape([
                            1, frame_ab.shape[0], 1, frame_ab.shape[1], frame_ab.shape[2]])

                    if input is not None:
                        paddle.concat((input, frame_l), 2)
                    input = frame_l if i == 0 else paddle.concat((input, frame_l), 2)
                    if nchannels == 3 and not self.colorization:
                        gtC = frame_ab if i == 0 else paddle.concat((gtC, frame_ab), 2)

                input = paddle.to_tensor(input)

                output_l = self.modelR(input)  # [B, C, T, H, W]

                # Save restoration output without colorization when using the option [--disable_colorization]
                if not self.colorization:
                    for i in range(proc_g):
                        index = frame_pos + i
                        if nchannels == 3:
                            out_l = output_l.detach()[0, :, i]
                            out_ab = gtC[0, :, i]

                            out = paddle.concat((out_l, out_ab),
                                                axis=0).detach().numpy().transpose((1, 2, 0))
                            out = Image.fromarray(
                                np.uint8(utils.convertLAB2RGB(out) * 255))
                            out.save(outputdir_out + '%07d.png' % (index))
                        else:
                            raise ValueError('channels of imag3 must be 3!')
                # Perform colorization
                else:
                    if self.reference_dir is None:
                        output_ab = self.modelC(output_l)
                    else:
                        output_ab = self.modelC(output_l, refimgs)
                    output_l = output_l.detach()
                    output_ab = output_ab.detach()

                    for i in range(proc_g):
                        index = frame_pos + i
                        out_l = output_l[0, :, i, :, :]
                        out_c = output_ab[0, :, i, :, :]
                        output = paddle.concat(
                            (out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
                        output = Image.fromarray(
                            np.uint8(utils.convertLAB2RGB(output) * 255))
                        output.save(outputdir_out + '%07d.png' % index)

                it = it + 1
                pbar.update(proc_g)

        # Save result videos
        outfile = os.path.join(outputdir,
                               self.input.split('/')[-1].split('.')[0])
        cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (
            fps, outputdir_in, fps, outfile)
        subprocess.call(cmd, shell=True)
        cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (
            fps, outputdir_out, fps, outfile)
        subprocess.call(cmd, shell=True)
        cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % (
            outfile, outfile, outfile)
        subprocess.call(cmd, shell=True)

        cap.release()
        pbar.close()
        return outputdir_out, '%s_out.mp4' % outfile


if __name__ == "__main__":
    args = parser.parse_args()
    paddle.disable_static()
    predictor = DeepReasterPredictor(args.input,
                                     args.output,
                                     colorization=args.colorization,
                                     reference_dir=args.reference_dir,
                                     mindim=args.mindim)
    predictor.run()
```
\ No newline at end of file
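For reference, a minimal way to drive this predictor from Python, mirroring the `__main__` block above. This is a sketch only: it assumes it is run from inside `applications/DeepRemaster/` so the local `utils`/`remasternet` imports resolve, that the pretrained weights download succeeds, and `old_film.mp4` / `ref_images/` are placeholder paths:

```python
import paddle
from predict import DeepReasterPredictor

paddle.disable_static()  # the commit targets pre-2.0 dygraph mode
predictor = DeepReasterPredictor(
    'old_film.mp4',              # placeholder input video
    'output',                    # results land in output/DeepRemaster/{input,output}
    colorization=True,           # also run NetworkC, guided by reference frames
    reference_dir='ref_images')  # placeholder directory of color reference stills
frames_dir, out_video = predictor.run()
print(frames_dir, out_video)     # output/DeepRemaster/output/ and the *_out.mp4 path
```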
applications/DeepRemaster/remasternet.py (new file, mode 100644)

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class TempConv(nn.Layer):
    def __init__(self, in_planes, out_planes,
                 kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)):
        super(TempConv, self).__init__()
        self.conv3d = nn.Conv3d(in_planes, out_planes,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding)
        self.bn = nn.BatchNorm(out_planes)

    def forward(self, x):
        return F.elu(self.bn(self.conv3d(x)))


class Upsample(nn.Layer):
    def __init__(self, in_planes, out_planes, scale_factor=(1, 2, 2)):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.conv3d = nn.Conv3d(in_planes, out_planes,
                                kernel_size=(3, 3, 3),
                                stride=(1, 1, 1),
                                padding=(1, 1, 1))
        self.bn = nn.BatchNorm(out_planes)

    def forward(self, x):
        out_size = x.shape[2:]
        for i in range(3):
            out_size[i] = self.scale_factor[i] * out_size[i]
        return F.elu(self.bn(self.conv3d(
            F.interpolate(x, size=out_size, mode='trilinear',
                          align_corners=False, data_format='NCDHW',
                          align_mode=0))))


class UpsampleConcat(nn.Layer):
    def __init__(self, in_planes_up, in_planes_flat, out_planes):
        super(UpsampleConcat, self).__init__()
        self.conv3d = TempConv(in_planes_up + in_planes_flat, out_planes,
                               kernel_size=(3, 3, 3),
                               stride=(1, 1, 1),
                               padding=(1, 1, 1))

    def forward(self, x1, x2):
        scale_factor = (1, 2, 2)
        out_size = x1.shape[2:]
        for i in range(3):
            out_size[i] = scale_factor[i] * out_size[i]
        x1 = F.interpolate(x1, size=out_size, mode='trilinear',
                           align_corners=False, data_format='NCDHW',
                           align_mode=0)
        x = paddle.concat([x1, x2], axis=1)
        return self.conv3d(x)


class SourceReferenceAttention(paddle.fluid.dygraph.Layer):
    """
    Source-Reference Attention Layer
    """
    def __init__(self, in_planes_s, in_planes_r):
        """
        Parameters
        ----------
        in_planes_s: int
            Number of input source feature vector channels.
        in_planes_r: int
            Number of input reference feature vector channels.
        """
        super(SourceReferenceAttention, self).__init__()
        self.query_conv = nn.Conv3d(in_channels=in_planes_s,
                                    out_channels=in_planes_s // 8,
                                    kernel_size=1)
        self.key_conv = nn.Conv3d(in_channels=in_planes_r,
                                  out_channels=in_planes_r // 8,
                                  kernel_size=1)
        self.value_conv = nn.Conv3d(in_channels=in_planes_r,
                                    out_channels=in_planes_r,
                                    kernel_size=1)
        self.gamma = self.create_parameter(
            shape=[1],
            dtype=self.query_conv.weight.dtype,
            default_initializer=paddle.fluid.initializer.Constant(0.0))

    def forward(self, source, reference):
        s_batchsize, sC, sT, sH, sW = source.shape
        r_batchsize, rC, rT, rH, rW = reference.shape

        proj_query = paddle.reshape(self.query_conv(source),
                                    [s_batchsize, -1, sT * sH * sW])
        proj_query = paddle.transpose(proj_query, [0, 2, 1])
        proj_key = paddle.reshape(self.key_conv(reference),
                                  [r_batchsize, -1, rT * rW * rH])
        energy = paddle.bmm(proj_query, proj_key)
        attention = F.softmax(energy)

        proj_value = paddle.reshape(self.value_conv(reference),
                                    [r_batchsize, -1, rT * rH * rW])

        out = paddle.bmm(proj_value, paddle.transpose(attention, [0, 2, 1]))
        out = paddle.reshape(out, [s_batchsize, sC, sT, sH, sW])
        out = self.gamma * out + source
        return out, attention


class NetworkR(nn.Layer):
    def __init__(self):
        super(NetworkR, self).__init__()

        self.layers = nn.Sequential(
            nn.ReplicationPad3d((1, 1, 1, 1, 1, 1)),
            TempConv(1, 64, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(0, 0, 0)),
            TempConv(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            TempConv(128, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            TempConv(128, 256, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
            TempConv(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            TempConv(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            TempConv(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            TempConv(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            Upsample(256, 128),
            TempConv(128, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            TempConv(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            Upsample(64, 16),
            nn.Conv3d(16, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)))

    def forward(self, x):
        return paddle.clip(
            (x + paddle.fluid.layers.tanh(
                self.layers(((x * 1).detach()) - 0.4462414))),
            0.0, 1.0)


class NetworkC(nn.Layer):
    def __init__(self):
        super(NetworkC, self).__init__()

        self.down1 = nn.Sequential(
            nn.ReplicationPad3d((1, 1, 1, 1, 0, 0)),
            TempConv(1, 64, stride=(1, 2, 2), padding=(0, 0, 0)),
            TempConv(64, 128),
            TempConv(128, 128),
            TempConv(128, 256, stride=(1, 2, 2)),
            TempConv(256, 256),
            TempConv(256, 256),
            TempConv(256, 512, stride=(1, 2, 2)),
            TempConv(512, 512),
            TempConv(512, 512))
        self.flat = nn.Sequential(
            TempConv(512, 512),
            TempConv(512, 512))
        self.down2 = nn.Sequential(
            TempConv(512, 512, stride=(1, 2, 2)),
            TempConv(512, 512),
        )
        self.stattn1 = SourceReferenceAttention(512, 512)  # Source-Reference Attention
        self.stattn2 = SourceReferenceAttention(512, 512)  # Source-Reference Attention
        self.selfattn1 = SourceReferenceAttention(512, 512)  # Self Attention
        self.conv1 = TempConv(512, 512)
        self.up1 = UpsampleConcat(512, 512, 512)  # 1/8
        self.selfattn2 = SourceReferenceAttention(512, 512)  # Self Attention
        self.conv2 = TempConv(512, 256,
                              kernel_size=(3, 3, 3),
                              stride=(1, 1, 1),
                              padding=(1, 1, 1))
        self.up2 = nn.Sequential(
            Upsample(256, 128),  # 1/4
            TempConv(128, 64,
                     kernel_size=(3, 3, 3),
                     stride=(1, 1, 1),
                     padding=(1, 1, 1)))
        self.up3 = nn.Sequential(
            Upsample(64, 32),  # 1/2
            TempConv(32, 16,
                     kernel_size=(3, 3, 3),
                     stride=(1, 1, 1),
                     padding=(1, 1, 1)))
        self.up4 = nn.Sequential(
            Upsample(16, 8),  # 1/1
            nn.Conv3d(8, 2,
                      kernel_size=(3, 3, 3),
                      stride=(1, 1, 1),
                      padding=(1, 1, 1)))
        self.reffeatnet1 = nn.Sequential(
            TempConv(3, 64, stride=(1, 2, 2)),
            TempConv(64, 128),
            TempConv(128, 128),
            TempConv(128, 256, stride=(1, 2, 2)),
            TempConv(256, 256),
            TempConv(256, 256),
            TempConv(256, 512, stride=(1, 2, 2)),
            TempConv(512, 512),
            TempConv(512, 512),
        )
        self.reffeatnet2 = nn.Sequential(
            TempConv(512, 512, stride=(1, 2, 2)),
            TempConv(512, 512),
            TempConv(512, 512),
        )

    def forward(self, x, x_refs=None):
        x1 = self.down1(x - 0.4462414)
        if x_refs is not None:
            x_refs = paddle.transpose(x_refs, [0, 2, 1, 3, 4])  # [B,T,C,H,W] --> [B,C,T,H,W]
            reffeat = self.reffeatnet1(x_refs - 0.48)
            x1, _ = self.stattn1(x1, reffeat)

        x2 = self.flat(x1)
        out = self.down2(x1)
        if x_refs is not None:
            reffeat2 = self.reffeatnet2(reffeat)
            out, _ = self.stattn2(out, reffeat2)
        out = self.conv1(out)
        out, _ = self.selfattn1(out, out)
        out = self.up1(out, x2)
        out, _ = self.selfattn2(out, out)
        out = self.conv2(out)
        out = self.up2(out)
        out = self.up3(out)
        out = self.up4(out)

        return F.sigmoid(out)
```
\ No newline at end of file
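A quick shape sanity check for the restoration branch: `NetworkR` consumes and produces 5-D `[B, C, T, H, W]` luminance tensors, and its two stride-(1, 2, 2) downsamples are undone by the two `Upsample` stages, so spatial dims must be divisible by 4 (the predictor above resizes to multiples of 16). A sketch, assuming the 2020-era Paddle APIs this commit targets are available and run from inside `applications/DeepRemaster/`:

```python
import numpy as np
import paddle
from remasternet import NetworkR

paddle.disable_static()
model = NetworkR()
model.eval()

# One 5-frame grayscale block: [B=1, C=1, T=5, H=144, W=176]
x = paddle.to_tensor(np.random.rand(1, 1, 5, 144, 176).astype('float32'))
y = model(x)
print(y.shape)  # expected [1, 1, 5, 144, 176]: restored luminance clipped to [0, 1]
```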
applications/DeepRemaster/utils.py (new file, mode 100644)

```python
import paddle
from skimage import color
import numpy as np
from PIL import Image


def convertLAB2RGB(lab):
    lab[:, :, 0:1] = lab[:, :, 0:1] * 100  # [0, 1] -> [0, 100]
    lab[:, :, 1:3] = np.clip(lab[:, :, 1:3] * 255 - 128, -100, 100)  # [0, 1] -> [-128, 128]
    rgb = color.lab2rgb(lab.astype(np.float64))
    return rgb


def convertRGB2LABTensor(rgb):
    lab = color.rgb2lab(np.asarray(rgb))  # RGB -> LAB L[0, 100] a[-127, 128] b[-128, 127]
    ab = np.clip(lab[:, :, 1:3] + 128, 0, 255)  # AB --> [0, 255]
    ab = paddle.to_tensor(ab.astype('float32')) / 255.
    L = lab[:, :, 0] * 2.55  # L --> [0, 255]
    L = Image.fromarray(np.uint8(L))
    L = paddle.to_tensor(np.array(L).astype('float32')[..., np.newaxis] / 255.0)
    return L, ab


def addMergin(img, target_w, target_h, background_color=(0, 0, 0)):
    width, height = img.size
    if width == target_w and height == target_h:
        return img
    scale = max(target_w, target_h) / max(width, height)
    width = int(width * scale / 16.) * 16
    height = int(height * scale / 16.) * 16
    img = img.resize((width, height), Image.BICUBIC)

    xp = (target_w - width) // 2
    yp = (target_h - height) // 2
    result = Image.new(img.mode, (target_w, target_h), background_color)
    result.paste(img, (xp, yp))
    return result
```
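The two color converters are near-inverses: `convertRGB2LABTensor` scales L and ab into [0, 1] tensors, and `convertLAB2RGB` undoes that scaling before `skimage`'s `lab2rgb`. A round-trip sketch (assumes paddle, scikit-image, and Pillow are installed and the script is run next to `utils.py`):

```python
import numpy as np
import paddle
from utils import convertRGB2LABTensor, convertLAB2RGB

paddle.disable_static()
rgb = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
L, ab = convertRGB2LABTensor(rgb)
print(L.shape, ab.shape)  # (32, 32, 1) and (32, 32, 2), both scaled into [0, 1]

lab = np.concatenate([L.numpy(), ab.numpy()], axis=2)
restored = convertLAB2RGB(lab)  # float RGB in [0, 1]
# Typically a small error, bounded by the uint8 quantization of L;
# the ab clipping can move very saturated colors further.
print(np.abs(restored - rgb / 255.).max())
```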
applications/EDVR/data.py (new file, mode 100644)

```python
import cv2
import numpy as np


def read_img(path, size=None, is_gt=False):
    """read image by cv2
    return: Numpy float32, HWC, BGR, [0,1]"""
    # print('debug:', path)
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
    """Generate an index list for reading N frames from a sequence of images
    Args:
        crt_i (int): current center index
        max_n (int): max number of the sequence of images (calculated from 1)
        N (int): reading N frames
        padding (str): padding mode, one of replicate | reflection | new_info | circle
            Example: crt_i = 0, N = 5
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            new_info: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]
    Returns:
        return_l (list [int]): a list of indexes
    """
    max_n = max_n - 1
    n_pad = N // 2
    return_l = []

    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            if padding == 'replicate':
                add_idx = 0
            elif padding == 'reflection':
                add_idx = -i
            elif padding == 'new_info':
                add_idx = (crt_i + n_pad) + (-i)
            elif padding == 'circle':
                add_idx = N + i
            else:
                raise ValueError('Wrong padding mode')
        elif i > max_n:
            if padding == 'replicate':
                add_idx = max_n
            elif padding == 'reflection':
                add_idx = max_n * 2 - i
            elif padding == 'new_info':
                add_idx = (crt_i - n_pad) - (i - max_n)
            elif padding == 'circle':
                add_idx = i - N
            else:
                raise ValueError('Wrong padding mode')
        else:
            add_idx = i
        return_l.append(add_idx)
    # name_b = '{:08d}'.format(crt_i)
    return return_l


class EDVRDataset:
    def __init__(self, frame_paths):
        self.frames = frame_paths

    def __getitem__(self, index):
        indexs = get_test_neighbor_frames(index, 5, len(self.frames))
        frame_list = []
        for i in indexs:
            img = read_img(self.frames[i])
            frame_list.append(img)

        img_LQs = np.stack(frame_list, axis=0)
        print('img:', img_LQs.shape)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_LQs = img_LQs[:, :, :, [2, 1, 0]]
        img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')

        return img_LQs, self.frames[index]

    def __len__(self):
        return len(self.frames)
```
\ No newline at end of file
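The docstring's padding modes are easiest to see by example. With the 5-frame window that `EDVRDataset.__getitem__` uses, the default 'new_info' mode fills out-of-range positions with the nearest frames not already in the window:

```python
from data import get_test_neighbor_frames

# 10-frame clip, 5-frame window (the exact call EDVRDataset makes):
print(get_test_neighbor_frames(0, 5, 10))  # [4, 3, 0, 1, 2]  -- left edge padded
print(get_test_neighbor_frames(5, 5, 10))  # [3, 4, 5, 6, 7]  -- interior, unchanged
print(get_test_neighbor_frames(9, 5, 10))  # [7, 8, 9, 6, 5]  -- right edge padded
```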
applications/run.sh

```diff
@@ -10,4 +10,4 @@ cd -
 # proccess_order: the order in which the models are applied
 python tools/main.py \
-    --input input.mp4 --output output --proccess_order DAIN DeOldify EDVR
+    --input input.mp4 --output output --proccess_order DAIN DeepRemaster DeOldify EDVR
```
applications/tools/main.py → applications/tools/video-enhance.py (renamed)

@@ -5,44 +5,55 @@ import argparse

The hunk rewrites the argument parsing and main loop for DeepRemaster; the renamed script after this commit reads:

```python
import paddle

from DAIN.predict import VideoFrameInterp
from DeepRemaster.predict import DeepReasterPredictor
from DeOldify.predict import DeOldifyPredictor
from EDVR.predict import EDVRPredictor

parser = argparse.ArgumentParser(description='Fix video')
parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--DAIN_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--DeepRemaster_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--DeOldify_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--EDVR_weight', type=str, default=None, help='Path to model weight')
# DAIN args
parser.add_argument('--time_step', type=float, default=0.5, help='choose the time steps')
# DeepRemaster args
parser.add_argument('--reference_dir', type=str, default=None,
                    help='Path to the reference image directory')
parser.add_argument('--colorization', action='store_true', default=False,
                    help='Remaster with colorization')
parser.add_argument('--mindim', type=int, default=360,
                    help='Length of minimum image edges')
#process order support model name:[DAIN, DeepRemaster, DeOldify, EDVR]
parser.add_argument('--proccess_order', type=str, default='none', nargs='+',
                    help='Process order')

if __name__ == "__main__":
    args = parser.parse_args()
    print('args...', args)
    orders = args.proccess_order
    temp_video_path = None

    for order in orders:
        if temp_video_path is None:
            temp_video_path = args.input
        if order == 'DAIN':
            predictor = VideoFrameInterp(args.time_step, args.DAIN_weight,
                                         temp_video_path,
                                         output_path=args.output)
            frames_path, temp_video_path = predictor.run()
        elif order == 'DeepRemaster':
            paddle.disable_static()
            predictor = DeepReasterPredictor(
                temp_video_path,
                args.output,
                weight_path=args.DeepRemaster_weight,
                colorization=args.colorization,
                reference_dir=args.reference_dir,
                mindim=args.mindim)
            frames_path, temp_video_path = predictor.run()
            paddle.enable_static()
        elif order == 'DeOldify':
            paddle.disable_static()
            predictor = DeOldifyPredictor(temp_video_path, args.output,
                                          weight_path=args.DeOldify_weight)
            frames_path, temp_video_path = predictor.run()
            paddle.enable_static()
        elif order == 'EDVR':
            predictor = EDVRPredictor(temp_video_path, args.output,
                                      weight_path=args.EDVR_weight)
            frames_path, temp_video_path = predictor.run()

        print('Model {} output frames path:'.format(order), frames_path)
        print('Model {} output video path:'.format(order), temp_video_path)
```
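The pipeline contract that makes this chaining work: every predictor's `run()` returns `(frames_path, temp_video_path)`, and the next stage reads `temp_video_path` as its input. A minimal sketch of the pattern with a hypothetical stand-in stage (`EchoStage` is illustrative, not from the repo):

```python
class EchoStage:
    """Hypothetical stand-in for a predictor: consumes a video path, returns new paths."""
    def __init__(self, name, video_path):
        self.name = name
        self.video_path = video_path

    def run(self):
        # A real predictor writes frames plus a processed video, then returns both paths.
        frames_path = 'output/%s/frames' % self.name
        out_video = '%s.%s.mp4' % (self.video_path.rsplit('.', 1)[0], self.name)
        return frames_path, out_video


temp_video_path = 'input.mp4'
for order in ['DAIN', 'DeepRemaster', 'DeOldify', 'EDVR']:
    frames_path, temp_video_path = EchoStage(order, temp_video_path).run()
    print(order, '->', temp_video_path)  # each stage feeds the next stage's input
```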