Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
bbe1f14d
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
98
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bbe1f14d
编写于
9月 18, 2020
作者:
L
LielinJiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add first-order-model to applications
上级
8a4848dc
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
1075 addition
and
0 deletion
+1075
-0
applications/first_order_model/configs/vox-256.yaml
applications/first_order_model/configs/vox-256.yaml
+83
-0
applications/tools/first-order-demo.py
applications/tools/first-order-demo.py
+221
-0
ppgan/models/generators/occlusion_aware.py
ppgan/models/generators/occlusion_aware.py
+131
-0
ppgan/modules/dense_motion.py
ppgan/modules/dense_motion.py
+154
-0
ppgan/modules/first_order.py
ppgan/modules/first_order.py
+282
-0
ppgan/modules/keypoint_detector.py
ppgan/modules/keypoint_detector.py
+94
-0
ppgan/utils/animate.py
ppgan/utils/animate.py
+110
-0
未找到文件。
applications/first_order_model/configs/vox-256.yaml
0 → 100644
浏览文件 @
bbe1f14d
dataset_params
:
root_dir
:
data/vox-png
frame_shape
:
[
256
,
256
,
3
]
id_sampling
:
True
pairs_list
:
data/vox256.csv
augmentation_params
:
flip_param
:
horizontal_flip
:
True
time_flip
:
True
jitter_param
:
brightness
:
0.1
contrast
:
0.1
saturation
:
0.1
hue
:
0.1
model_params
:
common_params
:
num_kp
:
10
num_channels
:
3
estimate_jacobian
:
True
kp_detector_params
:
temperature
:
0.1
block_expansion
:
32
max_features
:
1024
scale_factor
:
0.25
num_blocks
:
5
generator_params
:
block_expansion
:
64
max_features
:
512
num_down_blocks
:
2
num_bottleneck_blocks
:
6
estimate_occlusion_map
:
True
dense_motion_params
:
block_expansion
:
64
max_features
:
1024
num_blocks
:
5
scale_factor
:
0.25
discriminator_params
:
scales
:
[
1
]
block_expansion
:
32
max_features
:
512
num_blocks
:
4
sn
:
True
train_params
:
num_epochs
:
100
num_repeats
:
75
epoch_milestones
:
[
60
,
90
]
lr_generator
:
2.0e-4
lr_discriminator
:
2.0e-4
lr_kp_detector
:
2.0e-4
batch_size
:
40
scales
:
[
1
,
0.5
,
0.25
,
0.125
]
checkpoint_freq
:
50
transform_params
:
sigma_affine
:
0.05
sigma_tps
:
0.005
points_tps
:
5
loss_weights
:
generator_gan
:
0
discriminator_gan
:
1
feature_matching
:
[
10
,
10
,
10
,
10
]
perceptual
:
[
10
,
10
,
10
,
10
,
10
]
equivariance_value
:
10
equivariance_jacobian
:
10
reconstruction_params
:
num_videos
:
1000
format
:
'
.mp4'
animate_params
:
num_pairs
:
50
format
:
'
.mp4'
normalization_params
:
adapt_movement_scale
:
False
use_relative_movement
:
True
use_relative_jacobian
:
True
visualizer_params
:
kp_size
:
5
draw_border
:
True
colormap
:
'
gist_rainbow'
applications/tools/first-order-demo.py
0 → 100644
浏览文件 @
bbe1f14d
import
matplotlib
matplotlib
.
use
(
'Agg'
)
import
os
import
sys
# cur_path = os.path.abspath(os.path.dirname(__file__))
# root_path = os.path.split(cur_path)[0]
# sys.path.append(root_path)
import
yaml
import
pickle
from
argparse
import
ArgumentParser
from
tqdm
import
tqdm
import
imageio
import
numpy
as
np
from
skimage.transform
import
resize
from
skimage
import
img_as_ubyte
import
paddle
from
ppgan.models.generators.occlusion_aware
import
OcclusionAwareGenerator
from
ppgan.modules.keypoint_detector
import
KPDetector
from
ppgan.utils.animate
import
normalize_kp
from
scipy.spatial
import
ConvexHull
paddle
.
disable_static
()
if
sys
.
version_info
[
0
]
<
3
:
raise
Exception
(
"You must use Python 3 or higher. Recommended version is Python 3.7"
)
def
load_checkpoints
(
config_path
,
checkpoint_path
,
cpu
=
False
):
with
open
(
config_path
)
as
f
:
config
=
yaml
.
load
(
f
)
generator
=
OcclusionAwareGenerator
(
**
config
[
'model_params'
][
'generator_params'
],
**
config
[
'model_params'
][
'common_params'
])
kp_detector
=
KPDetector
(
**
config
[
'model_params'
][
'kp_detector_params'
],
**
config
[
'model_params'
][
'common_params'
])
checkpoint
=
pickle
.
load
(
open
(
checkpoint_path
,
'rb'
))
generator
.
set_state_dict
(
checkpoint
[
'generator'
])
kp_detector
.
set_state_dict
(
checkpoint
[
'kp_detector'
])
generator
.
eval
()
kp_detector
.
eval
()
return
generator
,
kp_detector
def
make_animation
(
source_image
,
driving_video
,
generator
,
kp_detector
,
relative
=
True
,
adapt_movement_scale
=
True
,
cpu
=
False
):
with
paddle
.
no_grad
():
predictions
=
[]
source
=
paddle
.
to_tensor
(
source_image
[
np
.
newaxis
].
astype
(
np
.
float32
)).
transpose
([
0
,
3
,
1
,
2
])
# if not cpu:
# source = source.cuda()
driving
=
paddle
.
to_tensor
(
np
.
array
(
driving_video
)[
np
.
newaxis
].
astype
(
np
.
float32
)).
transpose
(
[
0
,
4
,
1
,
2
,
3
])
kp_source
=
kp_detector
(
source
)
kp_driving_initial
=
kp_detector
(
driving
[:,
:,
0
])
for
frame_idx
in
tqdm
(
range
(
driving
.
shape
[
2
])):
driving_frame
=
driving
[:,
:,
frame_idx
]
kp_driving
=
kp_detector
(
driving_frame
)
kp_norm
=
normalize_kp
(
kp_source
=
kp_source
,
kp_driving
=
kp_driving
,
kp_driving_initial
=
kp_driving_initial
,
use_relative_movement
=
relative
,
use_relative_jacobian
=
relative
,
adapt_movement_scale
=
adapt_movement_scale
)
out
=
generator
(
source
,
kp_source
=
kp_source
,
kp_driving
=
kp_norm
)
predictions
.
append
(
np
.
transpose
(
out
[
'prediction'
].
numpy
(),
[
0
,
2
,
3
,
1
])[
0
])
return
predictions
def
find_best_frame
(
source
,
driving
,
cpu
=
False
):
import
face_alignment
def
normalize_kp
(
kp
):
kp
=
kp
-
kp
.
mean
(
axis
=
0
,
keepdims
=
True
)
area
=
ConvexHull
(
kp
[:,
:
2
]).
volume
area
=
np
.
sqrt
(
area
)
kp
[:,
:
2
]
=
kp
[:,
:
2
]
/
area
return
kp
fa
=
face_alignment
.
FaceAlignment
(
face_alignment
.
LandmarksType
.
_2D
,
flip_input
=
True
,
device
=
'cpu'
if
cpu
else
'cuda'
)
kp_source
=
fa
.
get_landmarks
(
255
*
source
)[
0
]
kp_source
=
normalize_kp
(
kp_source
)
norm
=
float
(
'inf'
)
frame_num
=
0
for
i
,
image
in
tqdm
(
enumerate
(
driving
)):
kp_driving
=
fa
.
get_landmarks
(
255
*
image
)[
0
]
kp_driving
=
normalize_kp
(
kp_driving
)
new_norm
=
(
np
.
abs
(
kp_source
-
kp_driving
)
**
2
).
sum
()
if
new_norm
<
norm
:
norm
=
new_norm
frame_num
=
i
return
frame_num
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
()
parser
.
add_argument
(
"--config"
,
required
=
True
,
help
=
"path to config"
)
parser
.
add_argument
(
"--checkpoint"
,
default
=
'vox-cpk.pth.tar'
,
help
=
"path to checkpoint to restore"
)
parser
.
add_argument
(
"--source_image"
,
default
=
'sup-mat/source.png'
,
help
=
"path to source image"
)
parser
.
add_argument
(
"--driving_video"
,
default
=
'sup-mat/source.png'
,
help
=
"path to driving video"
)
parser
.
add_argument
(
"--result_video"
,
default
=
'result.mp4'
,
help
=
"path to output"
)
parser
.
add_argument
(
"--relative"
,
dest
=
"relative"
,
action
=
"store_true"
,
help
=
"use relative or absolute keypoint coordinates"
)
parser
.
add_argument
(
"--adapt_scale"
,
dest
=
"adapt_scale"
,
action
=
"store_true"
,
help
=
"adapt movement scale based on convex hull of keypoints"
)
parser
.
add_argument
(
"--find_best_frame"
,
dest
=
"find_best_frame"
,
action
=
"store_true"
,
help
=
"Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)"
)
parser
.
add_argument
(
"--best_frame"
,
dest
=
"best_frame"
,
type
=
int
,
default
=
None
,
help
=
"Set frame to start from."
)
parser
.
add_argument
(
"--cpu"
,
dest
=
"cpu"
,
action
=
"store_true"
,
help
=
"cpu mode."
)
parser
.
set_defaults
(
relative
=
False
)
parser
.
set_defaults
(
adapt_scale
=
False
)
opt
=
parser
.
parse_args
()
source_image
=
imageio
.
imread
(
opt
.
source_image
)
reader
=
imageio
.
get_reader
(
opt
.
driving_video
)
fps
=
reader
.
get_meta_data
()[
'fps'
]
driving_video
=
[]
try
:
for
im
in
reader
:
driving_video
.
append
(
im
)
except
RuntimeError
:
pass
reader
.
close
()
source_image
=
resize
(
source_image
,
(
256
,
256
))[...,
:
3
]
driving_video
=
[
resize
(
frame
,
(
256
,
256
))[...,
:
3
]
for
frame
in
driving_video
]
generator
,
kp_detector
=
load_checkpoints
(
config_path
=
opt
.
config
,
checkpoint_path
=
opt
.
checkpoint
,
cpu
=
opt
.
cpu
)
if
opt
.
find_best_frame
or
opt
.
best_frame
is
not
None
:
i
=
opt
.
best_frame
if
opt
.
best_frame
is
not
None
else
find_best_frame
(
source_image
,
driving_video
,
cpu
=
opt
.
cpu
)
print
(
"Best frame: "
+
str
(
i
))
driving_forward
=
driving_video
[
i
:]
driving_backward
=
driving_video
[:(
i
+
1
)][::
-
1
]
predictions_forward
=
make_animation
(
source_image
,
driving_forward
,
generator
,
kp_detector
,
relative
=
opt
.
relative
,
adapt_movement_scale
=
opt
.
adapt_scale
,
cpu
=
opt
.
cpu
)
predictions_backward
=
make_animation
(
source_image
,
driving_backward
,
generator
,
kp_detector
,
relative
=
opt
.
relative
,
adapt_movement_scale
=
opt
.
adapt_scale
,
cpu
=
opt
.
cpu
)
predictions
=
predictions_backward
[::
-
1
]
+
predictions_forward
[
1
:]
else
:
predictions
=
make_animation
(
source_image
,
driving_video
,
generator
,
kp_detector
,
relative
=
opt
.
relative
,
adapt_movement_scale
=
opt
.
adapt_scale
,
cpu
=
opt
.
cpu
)
imageio
.
mimsave
(
opt
.
result_video
,
[
img_as_ubyte
(
frame
)
for
frame
in
predictions
],
fps
=
fps
)
ppgan/models/generators/occlusion_aware.py
0 → 100644
浏览文件 @
bbe1f14d
import
paddle
from
paddle
import
nn
import
paddle.nn.functional
as
F
from
...modules.first_order
import
ResBlock2d
,
SameBlock2d
,
UpBlock2d
,
DownBlock2d
from
...modules.dense_motion
import
DenseMotionNetwork
class
OcclusionAwareGenerator
(
nn
.
Layer
):
"""
Generator that given source image and and keypoints try to transform image according to movement trajectories
induced by keypoints. Generator follows Johnson architecture.
"""
def
__init__
(
self
,
num_channels
,
num_kp
,
block_expansion
,
max_features
,
num_down_blocks
,
num_bottleneck_blocks
,
estimate_occlusion_map
=
False
,
dense_motion_params
=
None
,
estimate_jacobian
=
False
):
super
(
OcclusionAwareGenerator
,
self
).
__init__
()
if
dense_motion_params
is
not
None
:
self
.
dense_motion_network
=
DenseMotionNetwork
(
num_kp
=
num_kp
,
num_channels
=
num_channels
,
estimate_occlusion_map
=
estimate_occlusion_map
,
**
dense_motion_params
)
else
:
self
.
dense_motion_network
=
None
self
.
first
=
SameBlock2d
(
num_channels
,
block_expansion
,
kernel_size
=
(
7
,
7
),
padding
=
(
3
,
3
))
down_blocks
=
[]
for
i
in
range
(
num_down_blocks
):
in_features
=
min
(
max_features
,
block_expansion
*
(
2
**
i
))
out_features
=
min
(
max_features
,
block_expansion
*
(
2
**
(
i
+
1
)))
down_blocks
.
append
(
DownBlock2d
(
in_features
,
out_features
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
)))
self
.
down_blocks
=
nn
.
LayerList
(
down_blocks
)
up_blocks
=
[]
for
i
in
range
(
num_down_blocks
):
in_features
=
min
(
max_features
,
block_expansion
*
(
2
**
(
num_down_blocks
-
i
)))
out_features
=
min
(
max_features
,
block_expansion
*
(
2
**
(
num_down_blocks
-
i
-
1
)))
up_blocks
.
append
(
UpBlock2d
(
in_features
,
out_features
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
)))
self
.
up_blocks
=
nn
.
LayerList
(
up_blocks
)
self
.
bottleneck
=
paddle
.
nn
.
Sequential
()
in_features
=
min
(
max_features
,
block_expansion
*
(
2
**
num_down_blocks
))
for
i
in
range
(
num_bottleneck_blocks
):
self
.
bottleneck
.
add_sublayer
(
'r'
+
str
(
i
),
ResBlock2d
(
in_features
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
)))
self
.
final
=
nn
.
Conv2d
(
block_expansion
,
num_channels
,
kernel_size
=
(
7
,
7
),
padding
=
(
3
,
3
))
self
.
estimate_occlusion_map
=
estimate_occlusion_map
self
.
num_channels
=
num_channels
def
deform_input
(
self
,
inp
,
deformation
):
_
,
h_old
,
w_old
,
_
=
deformation
.
shape
_
,
_
,
h
,
w
=
inp
.
shape
if
h_old
!=
h
or
w_old
!=
w
:
deformation
=
deformation
.
transpose
([
0
,
3
,
1
,
2
])
deformation
=
F
.
interpolate
(
deformation
,
size
=
(
h
,
w
),
mode
=
'bilinear'
)
deformation
=
deformation
.
transpose
([
0
,
2
,
3
,
1
])
return
F
.
grid_sample
(
inp
,
deformation
)
def
forward
(
self
,
source_image
,
kp_driving
,
kp_source
):
# Encoding (downsampling) part
out
=
self
.
first
(
source_image
)
for
i
in
range
(
len
(
self
.
down_blocks
)):
out
=
self
.
down_blocks
[
i
](
out
)
# Transforming feature representation according to deformation and occlusion
output_dict
=
{}
if
self
.
dense_motion_network
is
not
None
:
dense_motion
=
self
.
dense_motion_network
(
source_image
=
source_image
,
kp_driving
=
kp_driving
,
kp_source
=
kp_source
)
output_dict
[
'mask'
]
=
dense_motion
[
'mask'
]
output_dict
[
'sparse_deformed'
]
=
dense_motion
[
'sparse_deformed'
]
if
'occlusion_map'
in
dense_motion
:
occlusion_map
=
dense_motion
[
'occlusion_map'
]
output_dict
[
'occlusion_map'
]
=
occlusion_map
else
:
occlusion_map
=
None
deformation
=
dense_motion
[
'deformation'
]
out
=
self
.
deform_input
(
out
,
deformation
)
if
occlusion_map
is
not
None
:
if
out
.
shape
[
2
]
!=
occlusion_map
.
shape
[
2
]
or
out
.
shape
[
3
]
!=
occlusion_map
.
shape
[
3
]:
occlusion_map
=
F
.
interpolate
(
occlusion_map
,
size
=
out
.
shape
[
2
:],
mode
=
'bilinear'
)
out
=
out
*
occlusion_map
output_dict
[
"deformed"
]
=
self
.
deform_input
(
source_image
,
deformation
)
# Decoding part
out
=
self
.
bottleneck
(
out
)
for
i
in
range
(
len
(
self
.
up_blocks
)):
out
=
self
.
up_blocks
[
i
](
out
)
out
=
self
.
final
(
out
)
out
=
F
.
sigmoid
(
out
)
output_dict
[
"prediction"
]
=
out
return
output_dict
ppgan/modules/dense_motion.py
0 → 100644
浏览文件 @
bbe1f14d
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
.first_order
import
Hourglass
,
AntiAliasInterpolation2d
,
make_coordinate_grid
,
kp2gaussian
class
DenseMotionNetwork
(
nn
.
Layer
):
"""
Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
"""
def
__init__
(
self
,
block_expansion
,
num_blocks
,
max_features
,
num_kp
,
num_channels
,
estimate_occlusion_map
=
False
,
scale_factor
=
1
,
kp_variance
=
0.01
):
super
(
DenseMotionNetwork
,
self
).
__init__
()
self
.
hourglass
=
Hourglass
(
block_expansion
=
block_expansion
,
in_features
=
(
num_kp
+
1
)
*
(
num_channels
+
1
),
max_features
=
max_features
,
num_blocks
=
num_blocks
)
self
.
mask
=
nn
.
Conv2d
(
self
.
hourglass
.
out_filters
,
num_kp
+
1
,
kernel_size
=
(
7
,
7
),
padding
=
(
3
,
3
))
if
estimate_occlusion_map
:
self
.
occlusion
=
nn
.
Conv2d
(
self
.
hourglass
.
out_filters
,
1
,
kernel_size
=
(
7
,
7
),
padding
=
(
3
,
3
))
else
:
self
.
occlusion
=
None
self
.
num_kp
=
num_kp
self
.
scale_factor
=
scale_factor
self
.
kp_variance
=
kp_variance
if
self
.
scale_factor
!=
1
:
self
.
down
=
AntiAliasInterpolation2d
(
num_channels
,
self
.
scale_factor
)
def
create_heatmap_representations
(
self
,
source_image
,
kp_driving
,
kp_source
):
"""
Eq 6. in the paper H_k(z)
"""
spatial_size
=
source_image
.
shape
[
2
:]
gaussian_driving
=
kp2gaussian
(
kp_driving
,
spatial_size
=
spatial_size
,
kp_variance
=
self
.
kp_variance
)
gaussian_source
=
kp2gaussian
(
kp_source
,
spatial_size
=
spatial_size
,
kp_variance
=
self
.
kp_variance
)
heatmap
=
gaussian_driving
-
gaussian_source
#adding background feature
zeros
=
paddle
.
zeros
(
[
heatmap
.
shape
[
0
],
1
,
spatial_size
[
0
],
spatial_size
[
1
]],
heatmap
.
dtype
)
#.type(heatmap.type())
heatmap
=
paddle
.
concat
([
zeros
,
heatmap
],
axis
=
1
)
heatmap
=
heatmap
.
unsqueeze
(
2
)
return
heatmap
def
create_sparse_motions
(
self
,
source_image
,
kp_driving
,
kp_source
):
"""
Eq 4. in the paper T_{s<-d}(z)
"""
bs
,
_
,
h
,
w
=
source_image
.
shape
identity_grid
=
make_coordinate_grid
((
h
,
w
),
type
=
kp_source
[
'value'
].
dtype
)
identity_grid
=
identity_grid
.
reshape
([
1
,
1
,
h
,
w
,
2
])
coordinate_grid
=
identity_grid
-
kp_driving
[
'value'
].
reshape
(
[
bs
,
self
.
num_kp
,
1
,
1
,
2
])
if
'jacobian'
in
kp_driving
:
jacobian
=
paddle
.
matmul
(
kp_source
[
'jacobian'
],
paddle
.
inverse
(
kp_driving
[
'jacobian'
]))
jacobian
=
jacobian
.
unsqueeze
(
-
3
).
unsqueeze
(
-
3
)
jacobian
=
paddle
.
tile
(
jacobian
,
[
1
,
1
,
h
,
w
,
1
,
1
])
coordinate_grid
=
paddle
.
matmul
(
jacobian
,
coordinate_grid
.
unsqueeze
(
-
1
))
coordinate_grid
=
coordinate_grid
.
squeeze
(
-
1
)
driving_to_source
=
coordinate_grid
+
kp_source
[
'value'
].
reshape
(
[
bs
,
self
.
num_kp
,
1
,
1
,
2
])
#adding background feature
identity_grid
=
paddle
.
tile
(
identity_grid
,
(
bs
,
1
,
1
,
1
,
1
))
sparse_motions
=
paddle
.
concat
([
identity_grid
,
driving_to_source
],
axis
=
1
)
return
sparse_motions
def
create_deformed_source_image
(
self
,
source_image
,
sparse_motions
):
"""
Eq 7. in the paper \hat{T}_{s<-d}(z)
"""
bs
,
_
,
h
,
w
=
source_image
.
shape
source_repeat
=
paddle
.
tile
(
source_image
.
unsqueeze
(
1
).
unsqueeze
(
1
),
[
1
,
self
.
num_kp
+
1
,
1
,
1
,
1
,
1
])
#.repeat(1, self.num_kp + 1, 1, 1, 1, 1)
source_repeat
=
source_repeat
.
reshape
(
[
bs
*
(
self
.
num_kp
+
1
),
-
1
,
h
,
w
])
sparse_motions
=
sparse_motions
.
reshape
(
(
bs
*
(
self
.
num_kp
+
1
),
h
,
w
,
-
1
))
sparse_deformed
=
F
.
grid_sample
(
source_repeat
,
sparse_motions
,
align_corners
=
False
)
sparse_deformed
=
sparse_deformed
.
reshape
(
(
bs
,
self
.
num_kp
+
1
,
-
1
,
h
,
w
))
return
sparse_deformed
def
forward
(
self
,
source_image
,
kp_driving
,
kp_source
):
if
self
.
scale_factor
!=
1
:
source_image
=
self
.
down
(
source_image
)
bs
,
_
,
h
,
w
=
source_image
.
shape
out_dict
=
dict
()
heatmap_representation
=
self
.
create_heatmap_representations
(
source_image
,
kp_driving
,
kp_source
)
sparse_motion
=
self
.
create_sparse_motions
(
source_image
,
kp_driving
,
kp_source
)
deformed_source
=
self
.
create_deformed_source_image
(
source_image
,
sparse_motion
)
out_dict
[
'sparse_deformed'
]
=
deformed_source
input
=
paddle
.
concat
([
heatmap_representation
,
deformed_source
],
axis
=
2
)
input
=
input
.
reshape
([
bs
,
-
1
,
h
,
w
])
prediction
=
self
.
hourglass
(
input
)
mask
=
self
.
mask
(
prediction
)
mask
=
F
.
softmax
(
mask
,
axis
=
1
)
out_dict
[
'mask'
]
=
mask
mask
=
mask
.
unsqueeze
(
2
)
sparse_motion
=
sparse_motion
.
transpose
([
0
,
1
,
4
,
2
,
3
])
deformation
=
(
sparse_motion
*
mask
).
sum
(
axis
=
1
)
deformation
=
deformation
.
transpose
([
0
,
2
,
3
,
1
])
out_dict
[
'deformation'
]
=
deformation
# Sec. 3.2 in the paper
if
self
.
occlusion
:
occlusion_map
=
F
.
sigmoid
(
self
.
occlusion
(
prediction
))
out_dict
[
'occlusion_map'
]
=
occlusion_map
return
out_dict
ppgan/modules/first_order.py
0 → 100644
浏览文件 @
bbe1f14d
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
# from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
def
kp2gaussian
(
kp
,
spatial_size
,
kp_variance
):
"""
Transform a keypoint into gaussian like representation
"""
mean
=
kp
[
'value'
]
coordinate_grid
=
make_coordinate_grid
(
spatial_size
,
mean
.
dtype
)
number_of_leading_dimensions
=
len
(
mean
.
shape
)
-
1
shape
=
(
1
,
)
*
number_of_leading_dimensions
+
tuple
(
coordinate_grid
.
shape
)
coordinate_grid
=
coordinate_grid
.
reshape
([
*
shape
])
repeats
=
tuple
(
mean
.
shape
[:
number_of_leading_dimensions
])
+
(
1
,
1
,
1
)
coordinate_grid
=
paddle
.
tile
(
coordinate_grid
,
[
*
repeats
])
# Preprocess kp shape
shape
=
tuple
(
mean
.
shape
[:
number_of_leading_dimensions
])
+
(
1
,
1
,
2
)
mean
=
mean
.
reshape
(
shape
)
mean_sub
=
(
coordinate_grid
-
mean
)
out
=
paddle
.
exp
(
-
0.5
*
(
mean_sub
**
2
).
sum
(
-
1
)
/
kp_variance
)
return
out
def
make_coordinate_grid
(
spatial_size
,
type
):
"""
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
"""
h
,
w
=
spatial_size
x
=
paddle
.
arange
(
w
,
dtype
=
type
)
#.type(type)
y
=
paddle
.
arange
(
h
,
dtype
=
type
)
#.type(type)
x
=
(
2
*
(
x
/
(
w
-
1
))
-
1
)
y
=
(
2
*
(
y
/
(
h
-
1
))
-
1
)
yy
=
paddle
.
tile
(
y
.
reshape
([
-
1
,
1
]),
[
1
,
w
])
xx
=
paddle
.
tile
(
x
.
reshape
([
1
,
-
1
]),
[
h
,
1
])
meshed
=
paddle
.
concat
([
xx
.
unsqueeze
(
2
),
yy
.
unsqueeze
(
2
)],
2
)
# meshed = paddle.concat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
return
meshed
class
ResBlock2d
(
nn
.
Layer
):
"""
Res block, preserve spatial resolution.
"""
def
__init__
(
self
,
in_features
,
kernel_size
,
padding
):
super
(
ResBlock2d
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
in_channels
=
in_features
,
out_channels
=
in_features
,
kernel_size
=
kernel_size
,
padding
=
padding
)
self
.
conv2
=
nn
.
Conv2d
(
in_channels
=
in_features
,
out_channels
=
in_features
,
kernel_size
=
kernel_size
,
padding
=
padding
)
self
.
norm1
=
nn
.
BatchNorm2d
(
in_features
)
self
.
norm2
=
nn
.
BatchNorm2d
(
in_features
)
def
forward
(
self
,
x
):
out
=
self
.
norm1
(
x
)
out
=
F
.
relu
(
out
)
out
=
self
.
conv1
(
out
)
out
=
self
.
norm2
(
out
)
out
=
F
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
+=
x
return
out
class
UpBlock2d
(
nn
.
Layer
):
"""
Upsampling block for use in decoder.
"""
def
__init__
(
self
,
in_features
,
out_features
,
kernel_size
=
3
,
padding
=
1
,
groups
=
1
):
super
(
UpBlock2d
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
in_channels
=
in_features
,
out_channels
=
out_features
,
kernel_size
=
kernel_size
,
padding
=
padding
,
groups
=
groups
)
self
.
norm
=
nn
.
BatchNorm2d
(
out_features
)
def
forward
(
self
,
x
):
out
=
F
.
interpolate
(
x
,
scale_factor
=
2
)
out
=
self
.
conv
(
out
)
out
=
self
.
norm
(
out
)
out
=
F
.
relu
(
out
)
return
out
class
DownBlock2d
(
nn
.
Layer
):
"""
Downsampling block for use in encoder.
"""
def
__init__
(
self
,
in_features
,
out_features
,
kernel_size
=
3
,
padding
=
1
,
groups
=
1
):
super
(
DownBlock2d
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
in_channels
=
in_features
,
out_channels
=
out_features
,
kernel_size
=
kernel_size
,
padding
=
padding
,
groups
=
groups
)
self
.
norm
=
nn
.
BatchNorm2d
(
out_features
)
self
.
pool
=
nn
.
AvgPool2d
(
kernel_size
=
(
2
,
2
))
def
forward
(
self
,
x
):
out
=
self
.
conv
(
x
)
out
=
self
.
norm
(
out
)
out
=
F
.
relu
(
out
)
out
=
self
.
pool
(
out
)
return
out
class
SameBlock2d
(
nn
.
Layer
):
"""
Simple block, preserve spatial resolution.
"""
def
__init__
(
self
,
in_features
,
out_features
,
groups
=
1
,
kernel_size
=
3
,
padding
=
1
):
super
(
SameBlock2d
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
in_channels
=
in_features
,
out_channels
=
out_features
,
kernel_size
=
kernel_size
,
padding
=
padding
,
groups
=
groups
)
self
.
norm
=
nn
.
BatchNorm2d
(
out_features
)
def
forward
(
self
,
x
):
out
=
self
.
conv
(
x
)
out
=
self
.
norm
(
out
)
out
=
F
.
relu
(
out
)
return
out
class
Encoder
(
nn
.
Layer
):
"""
Hourglass Encoder
"""
def
__init__
(
self
,
block_expansion
,
in_features
,
num_blocks
=
3
,
max_features
=
256
):
super
(
Encoder
,
self
).
__init__
()
down_blocks
=
[]
for
i
in
range
(
num_blocks
):
down_blocks
.
append
(
DownBlock2d
(
in_features
if
i
==
0
else
min
(
max_features
,
block_expansion
*
(
2
**
i
)),
min
(
max_features
,
block_expansion
*
(
2
**
(
i
+
1
))),
kernel_size
=
3
,
padding
=
1
))
self
.
down_blocks
=
nn
.
LayerList
(
down_blocks
)
def
forward
(
self
,
x
):
outs
=
[
x
]
for
down_block
in
self
.
down_blocks
:
outs
.
append
(
down_block
(
outs
[
-
1
]))
return
outs
class
Decoder
(
nn
.
Layer
):
"""
Hourglass Decoder
"""
def
__init__
(
self
,
block_expansion
,
in_features
,
num_blocks
=
3
,
max_features
=
256
):
super
(
Decoder
,
self
).
__init__
()
up_blocks
=
[]
for
i
in
range
(
num_blocks
)[::
-
1
]:
in_filters
=
(
1
if
i
==
num_blocks
-
1
else
2
)
*
min
(
max_features
,
block_expansion
*
(
2
**
(
i
+
1
)))
out_filters
=
min
(
max_features
,
block_expansion
*
(
2
**
i
))
up_blocks
.
append
(
UpBlock2d
(
in_filters
,
out_filters
,
kernel_size
=
3
,
padding
=
1
))
self
.
up_blocks
=
nn
.
LayerList
(
up_blocks
)
self
.
out_filters
=
block_expansion
+
in_features
def
forward
(
self
,
x
):
out
=
x
.
pop
()
for
up_block
in
self
.
up_blocks
:
out
=
up_block
(
out
)
skip
=
x
.
pop
()
out
=
paddle
.
concat
([
out
,
skip
],
axis
=
1
)
return
out
class
Hourglass
(
nn
.
Layer
):
"""
Hourglass architecture.
"""
def
__init__
(
self
,
block_expansion
,
in_features
,
num_blocks
=
3
,
max_features
=
256
):
super
(
Hourglass
,
self
).
__init__
()
self
.
encoder
=
Encoder
(
block_expansion
,
in_features
,
num_blocks
,
max_features
)
self
.
decoder
=
Decoder
(
block_expansion
,
in_features
,
num_blocks
,
max_features
)
self
.
out_filters
=
self
.
decoder
.
out_filters
def
forward
(
self
,
x
):
return
self
.
decoder
(
self
.
encoder
(
x
))
class
AntiAliasInterpolation2d
(
nn
.
Layer
):
"""
Band-limited downsampling, for better preservation of the input signal.
"""
def
__init__
(
self
,
channels
,
scale
):
super
(
AntiAliasInterpolation2d
,
self
).
__init__
()
sigma
=
(
1
/
scale
-
1
)
/
2
kernel_size
=
2
*
round
(
sigma
*
4
)
+
1
self
.
ka
=
kernel_size
//
2
self
.
kb
=
self
.
ka
-
1
if
kernel_size
%
2
==
0
else
self
.
ka
kernel_size
=
[
kernel_size
,
kernel_size
]
sigma
=
[
sigma
,
sigma
]
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel
=
1
meshgrids
=
paddle
.
meshgrid
(
[
paddle
.
arange
(
size
,
dtype
=
'float32'
)
for
size
in
kernel_size
])
for
size
,
std
,
mgrid
in
zip
(
kernel_size
,
sigma
,
meshgrids
):
mean
=
(
size
-
1
)
/
2
kernel
*=
paddle
.
exp
(
-
(
mgrid
-
mean
)
**
2
/
(
2
*
std
**
2
))
# Make sure sum of values in gaussian kernel equals 1.
kernel
=
kernel
/
paddle
.
sum
(
kernel
)
# Reshape to depthwise convolutional weight
# print('debug shape:', kernel.shape)
# print('debug shape 1:', kernel.dim())
kernel
=
kernel
.
reshape
([
1
,
1
,
*
kernel
.
shape
])
# kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
kernel
=
paddle
.
tile
(
kernel
,
[
channels
,
*
[
1
]
*
(
kernel
.
dim
()
-
1
)])
self
.
register_buffer
(
'weight'
,
kernel
)
self
.
groups
=
channels
self
.
scale
=
scale
def
forward
(
self
,
input
):
if
self
.
scale
==
1.0
:
return
input
out
=
F
.
pad
(
input
,
[
self
.
ka
,
self
.
kb
,
self
.
ka
,
self
.
kb
])
out
=
F
.
conv2d
(
out
,
weight
=
self
.
weight
,
groups
=
self
.
groups
)
out
=
F
.
interpolate
(
out
,
scale_factor
=
[
self
.
scale
,
self
.
scale
])
return
out
ppgan/modules/keypoint_detector.py
0 → 100644
浏览文件 @
bbe1f14d
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
.first_order
import
Hourglass
,
make_coordinate_grid
,
AntiAliasInterpolation2d
class
KPDetector
(
nn
.
Layer
):
"""
Detecting a keypoints. Return keypoint position and jacobian near each keypoint.
"""
def
__init__
(
self
,
block_expansion
,
num_kp
,
num_channels
,
max_features
,
num_blocks
,
temperature
,
estimate_jacobian
=
False
,
scale_factor
=
1
,
single_jacobian_map
=
False
,
pad
=
0
):
super
(
KPDetector
,
self
).
__init__
()
self
.
predictor
=
Hourglass
(
block_expansion
,
in_features
=
num_channels
,
max_features
=
max_features
,
num_blocks
=
num_blocks
)
self
.
kp
=
nn
.
Conv2d
(
in_channels
=
self
.
predictor
.
out_filters
,
out_channels
=
num_kp
,
kernel_size
=
(
7
,
7
),
padding
=
pad
)
if
estimate_jacobian
:
self
.
num_jacobian_maps
=
1
if
single_jacobian_map
else
num_kp
self
.
jacobian
=
nn
.
Conv2d
(
in_channels
=
self
.
predictor
.
out_filters
,
out_channels
=
4
*
self
.
num_jacobian_maps
,
kernel_size
=
(
7
,
7
),
padding
=
pad
)
# self.jacobian.weight.data.zero_()
# self.jacobian.bias.data.copy_(paddle.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype='float32'))
else
:
self
.
jacobian
=
None
self
.
temperature
=
temperature
self
.
scale_factor
=
scale_factor
if
self
.
scale_factor
!=
1
:
self
.
down
=
AntiAliasInterpolation2d
(
num_channels
,
self
.
scale_factor
)
def
gaussian2kp
(
self
,
heatmap
):
"""
Extract the mean and from a heatmap
"""
shape
=
heatmap
.
shape
heatmap
=
heatmap
.
unsqueeze
(
-
1
)
grid
=
make_coordinate_grid
(
shape
[
2
:],
heatmap
.
dtype
).
unsqueeze
(
0
).
unsqueeze
(
0
)
value
=
(
heatmap
*
grid
).
sum
(
axis
=
(
2
,
3
))
kp
=
{
'value'
:
value
}
return
kp
def
forward
(
self
,
x
):
if
self
.
scale_factor
!=
1
:
x
=
self
.
down
(
x
)
feature_map
=
self
.
predictor
(
x
)
prediction
=
self
.
kp
(
feature_map
)
final_shape
=
prediction
.
shape
heatmap
=
prediction
.
reshape
([
final_shape
[
0
],
final_shape
[
1
],
-
1
])
heatmap
=
F
.
softmax
(
heatmap
/
self
.
temperature
,
axis
=
2
)
heatmap
=
heatmap
.
reshape
([
*
final_shape
])
out
=
self
.
gaussian2kp
(
heatmap
)
if
self
.
jacobian
is
not
None
:
jacobian_map
=
self
.
jacobian
(
feature_map
)
jacobian_map
=
jacobian_map
.
reshape
([
final_shape
[
0
],
self
.
num_jacobian_maps
,
4
,
final_shape
[
2
],
final_shape
[
3
]
])
heatmap
=
heatmap
.
unsqueeze
(
2
)
jacobian
=
heatmap
*
jacobian_map
jacobian
=
jacobian
.
reshape
([
final_shape
[
0
],
final_shape
[
1
],
4
,
-
1
])
jacobian
=
jacobian
.
sum
(
axis
=-
1
)
jacobian
=
jacobian
.
reshape
(
[
jacobian
.
shape
[
0
],
jacobian
.
shape
[
1
],
2
,
2
])
out
[
'jacobian'
]
=
jacobian
return
out
ppgan/utils/animate.py
0 → 100644
浏览文件 @
bbe1f14d
import
os
from
tqdm
import
tqdm
import
paddle
# from paddle.utils.data import DataLoader
# from frames_dataset import PairedDataset
# from logger import Logger, Visualizer
import
imageio
from
scipy.spatial
import
ConvexHull
import
numpy
as
np
# from sync_batchnorm import DataParallelWithCallback
def
normalize_kp
(
kp_source
,
kp_driving
,
kp_driving_initial
,
adapt_movement_scale
=
False
,
use_relative_movement
=
False
,
use_relative_jacobian
=
False
):
if
adapt_movement_scale
:
# source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
# driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
source_area
=
ConvexHull
(
kp_source
[
'value'
][
0
].
numpy
()).
volume
driving_area
=
ConvexHull
(
kp_driving_initial
[
'value'
][
0
].
numpy
()).
volume
adapt_movement_scale
=
np
.
sqrt
(
source_area
)
/
np
.
sqrt
(
driving_area
)
else
:
adapt_movement_scale
=
1
kp_new
=
{
k
:
v
for
k
,
v
in
kp_driving
.
items
()}
if
use_relative_movement
:
kp_value_diff
=
(
kp_driving
[
'value'
]
-
kp_driving_initial
[
'value'
])
kp_value_diff
*=
adapt_movement_scale
kp_new
[
'value'
]
=
kp_value_diff
+
kp_source
[
'value'
]
if
use_relative_jacobian
:
jacobian_diff
=
paddle
.
matmul
(
kp_driving
[
'jacobian'
],
paddle
.
inverse
(
kp_driving_initial
[
'jacobian'
]))
kp_new
[
'jacobian'
]
=
paddle
.
matmul
(
jacobian_diff
,
kp_source
[
'jacobian'
])
return
kp_new
# def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
# log_dir = os.path.join(log_dir, 'animation')
# png_dir = os.path.join(log_dir, 'png')
# animate_params = config['animate_params']
# dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
# dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
# if checkpoint is not None:
# Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
# else:
# raise AttributeError("Checkpoint should be specified for mode='animate'.")
# if not os.path.exists(log_dir):
# os.makedirs(log_dir)
# if not os.path.exists(png_dir):
# os.makedirs(png_dir)
# if torch.cuda.is_available():
# generator = DataParallelWithCallback(generator)
# kp_detector = DataParallelWithCallback(kp_detector)
# generator.eval()
# kp_detector.eval()
# for it, x in tqdm(enumerate(dataloader)):
# with torch.no_grad():
# predictions = []
# visualizations = []
# driving_video = x['driving_video']
# source_frame = x['source_video'][:, :, 0, :, :]
# kp_source = kp_detector(source_frame)
# kp_driving_initial = kp_detector(driving_video[:, :, 0])
# for frame_idx in range(driving_video.shape[2]):
# driving_frame = driving_video[:, :, frame_idx]
# kp_driving = kp_detector(driving_frame)
# kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
# kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
# out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)
# out['kp_driving'] = kp_driving
# out['kp_source'] = kp_source
# out['kp_norm'] = kp_norm
# del out['sparse_deformed']
# predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
# visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
# driving=driving_frame, out=out)
# visualization = visualization
# visualizations.append(visualization)
# predictions = np.concatenate(predictions, axis=1)
# result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
# imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))
# image_name = result_name + animate_params['format']
# imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录