Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
9ce257c6
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9ce257c6
编写于
10月 13, 2020
作者:
L
LielinJiang
提交者:
GitHub
10月 13, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #31 from LielinJiang/transforms
Reproduce transforms module
上级
2354ab9d
a24c71dc
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
478 addition
and
216 deletion
+478
-216
applications/first_order_model/configs/vox-256.yaml
applications/first_order_model/configs/vox-256.yaml
+0
-55
applications/tools/first-order-demo.py
applications/tools/first-order-demo.py
+21
-8
applications/tools/video-enhance.py
applications/tools/video-enhance.py
+14
-0
configs/cyclegan_cityscapes.yaml
configs/cyclegan_cityscapes.yaml
+20
-21
configs/cyclegan_horse2zebra.yaml
configs/cyclegan_horse2zebra.yaml
+20
-19
configs/pix2pix_cityscapes.yaml
configs/pix2pix_cityscapes.yaml
+29
-20
configs/pix2pix_cityscapes_2gpus.yaml
configs/pix2pix_cityscapes_2gpus.yaml
+28
-20
configs/pix2pix_facades.yaml
configs/pix2pix_facades.yaml
+28
-20
ppgan/datasets/paired_dataset.py
ppgan/datasets/paired_dataset.py
+8
-21
ppgan/datasets/transforms/__init__.py
ppgan/datasets/transforms/__init__.py
+1
-0
ppgan/datasets/transforms/builder.py
ppgan/datasets/transforms/builder.py
+46
-0
ppgan/datasets/transforms/transforms.py
ppgan/datasets/transforms/transforms.py
+241
-5
ppgan/datasets/unpaired_dataset.py
ppgan/datasets/unpaired_dataset.py
+20
-12
ppgan/models/builder.py
ppgan/models/builder.py
+0
-9
ppgan/utils/animate.py
ppgan/utils/animate.py
+2
-6
未找到文件。
applications/first_order_model/configs/vox-256.yaml
浏览文件 @
9ce257c6
dataset_params
:
root_dir
:
data/vox-png
frame_shape
:
[
256
,
256
,
3
]
id_sampling
:
True
pairs_list
:
data/vox256.csv
augmentation_params
:
flip_param
:
horizontal_flip
:
True
time_flip
:
True
jitter_param
:
brightness
:
0.1
contrast
:
0.1
saturation
:
0.1
hue
:
0.1
model_params
:
common_params
:
num_kp
:
10
...
...
@@ -42,42 +26,3 @@ model_params:
max_features
:
512
num_blocks
:
4
sn
:
True
train_params
:
num_epochs
:
100
num_repeats
:
75
epoch_milestones
:
[
60
,
90
]
lr_generator
:
2.0e-4
lr_discriminator
:
2.0e-4
lr_kp_detector
:
2.0e-4
batch_size
:
40
scales
:
[
1
,
0.5
,
0.25
,
0.125
]
checkpoint_freq
:
50
transform_params
:
sigma_affine
:
0.05
sigma_tps
:
0.005
points_tps
:
5
loss_weights
:
generator_gan
:
0
discriminator_gan
:
1
feature_matching
:
[
10
,
10
,
10
,
10
]
perceptual
:
[
10
,
10
,
10
,
10
,
10
]
equivariance_value
:
10
equivariance_jacobian
:
10
reconstruction_params
:
num_videos
:
1000
format
:
'
.mp4'
animate_params
:
num_pairs
:
50
format
:
'
.mp4'
normalization_params
:
adapt_movement_scale
:
False
use_relative_movement
:
True
use_relative_jacobian
:
True
visualizer_params
:
kp_size
:
5
draw_border
:
True
colormap
:
'
gist_rainbow'
applications/tools/first-order-demo.py
浏览文件 @
9ce257c6
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
matplotlib
matplotlib
.
use
(
'Agg'
)
import
os
...
...
@@ -5,20 +19,20 @@ import sys
import
yaml
import
pickle
from
argparse
import
ArgumentParser
from
tqdm
import
tqdm
import
imageio
import
numpy
as
np
from
skimage.transform
import
resize
from
tqdm
import
tqdm
from
skimage
import
img_as_ubyte
import
paddle
from
argparse
import
ArgumentParser
from
skimage.transform
import
resize
from
scipy.spatial
import
ConvexHull
from
ppgan.models.generators.occlusion_aware
import
OcclusionAwareGenerator
from
ppgan.modules.keypoint_detector
import
KPDetector
from
ppgan.utils.animate
import
normalize_kp
from
scipy.spatial
import
ConvexHull
import
paddle
paddle
.
disable_static
()
if
sys
.
version_info
[
0
]
<
3
:
...
...
@@ -60,8 +74,7 @@ def make_animation(source_image,
predictions
=
[]
source
=
paddle
.
to_tensor
(
source_image
[
np
.
newaxis
].
astype
(
np
.
float32
)).
transpose
([
0
,
3
,
1
,
2
])
# if not cpu:
# source = source.cuda()
driving
=
paddle
.
to_tensor
(
np
.
array
(
driving_video
)[
np
.
newaxis
].
astype
(
np
.
float32
)).
transpose
(
[
0
,
4
,
1
,
2
,
3
])
...
...
applications/tools/video-enhance.py
浏览文件 @
9ce257c6
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
sys
.
path
.
append
(
'.'
)
...
...
configs/cyclegan_cityscapes.yaml
浏览文件 @
9ce257c6
...
...
@@ -36,16 +36,18 @@ dataset:
output_nc
:
3
serial_batches
:
False
pool_size
:
50
transform
:
load_size
:
286
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
False
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
286
,
286
]
interpolation
:
2
#cv2.INTER_CUBIC
-
name
:
RandomCrop
output_size
:
[
256
,
256
]
-
name
:
RandomHorizontalFlip
prob
:
0.5
-
name
:
Permute
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
test
:
name
:
SingleDataset
dataroot
:
data/cityscapes/testB
...
...
@@ -55,17 +57,14 @@ dataset:
output_nc
:
3
serial_batches
:
False
pool_size
:
50
transform
:
load_size
:
256
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
True
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
256
,
256
]
interpolation
:
2
#cv2.INTER_CUBIC
-
name
:
Permute
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
optimizer
:
name
:
Adam
...
...
configs/cyclegan_horse2zebra.yaml
浏览文件 @
9ce257c6
...
...
@@ -35,16 +35,18 @@ dataset:
output_nc
:
3
serial_batches
:
False
pool_size
:
50
transform
:
load_size
:
286
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
False
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
286
,
286
]
interpolation
:
2
#cv2.INTER_CUBIC
-
name
:
RandomCrop
output_size
:
[
256
,
256
]
-
name
:
RandomHorizontalFlip
prob
:
0.5
-
name
:
Permute
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
test
:
name
:
SingleDataset
dataroot
:
data/horse2zebra/testA
...
...
@@ -55,15 +57,14 @@ dataset:
serial_batches
:
False
pool_size
:
50
transform
:
load_size
:
256
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
True
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transform
:
-
name
:
Resize
size
:
[
256
,
256
]
interpolation
:
2
#cv2.INTER_CUBIC
-
name
:
Permute
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
optimizer
:
name
:
Adam
...
...
configs/pix2pix_cityscapes.yaml
浏览文件 @
9ce257c6
...
...
@@ -33,16 +33,23 @@ dataset:
output_nc
:
3
serial_batches
:
False
pool_size
:
0
transform
:
load_size
:
286
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
False
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
286
,
286
]
interpolation
:
2
#cv2.INTER_CUBIC
keys
:
[
image
,
image
]
-
name
:
PairedRandomCrop
output_size
:
[
256
,
256
]
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
prob
:
0.5
keys
:
[
image
,
image
]
-
name
:
Permute
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
keys
:
[
image
,
image
]
test
:
name
:
PairedDataset
dataroot
:
data/cityscapes/
...
...
@@ -53,16 +60,18 @@ dataset:
output_nc
:
3
serial_batches
:
True
pool_size
:
50
transform
:
load_size
:
256
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
True
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
256
,
256
]
interpolation
:
2
#cv2.INTER_CUBIC
keys
:
[
image
,
image
]
-
name
:
Permute
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
keys
:
[
image
,
image
]
optimizer
:
name
:
Adam
...
...
configs/pix2pix_cityscapes_2gpus.yaml
浏览文件 @
9ce257c6
...
...
@@ -32,16 +32,23 @@ dataset:
output_nc
:
3
serial_batches
:
False
pool_size
:
0
transform
:
load_size
:
286
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
False
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
286
,
286
]
interpolation
:
2
#cv2.INTER_CUBIC
keys
:
[
image
,
image
]
-
name
:
PairedRandomCrop
output_size
:
[
256
,
256
]
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
prob
:
0.5
keys
:
[
image
,
image
]
-
name
:
Permute
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
keys
:
[
image
,
image
]
test
:
name
:
PairedDataset
dataroot
:
data/cityscapes/
...
...
@@ -52,16 +59,17 @@ dataset:
output_nc
:
3
serial_batches
:
True
pool_size
:
50
transform
:
load_size
:
256
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
True
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
256
,
256
]
interpolation
:
2
#cv2.INTER_CUBIC
keys
:
[
image
,
image
]
-
name
:
Permute
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
keys
:
[
image
,
image
]
optimizer
:
name
:
Adam
...
...
configs/pix2pix_facades.yaml
浏览文件 @
9ce257c6
...
...
@@ -32,16 +32,23 @@ dataset:
output_nc
:
3
serial_batches
:
False
pool_size
:
0
transform
:
load_size
:
286
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
False
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
286
,
286
]
interpolation
:
2
#cv2.INTER_CUBIC
keys
:
[
image
,
image
]
-
name
:
PairedRandomCrop
output_size
:
[
256
,
256
]
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
prob
:
0.5
keys
:
[
image
,
image
]
-
name
:
Permute
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
keys
:
[
image
,
image
]
test
:
name
:
PairedDataset
dataroot
:
data/facades/
...
...
@@ -52,16 +59,17 @@ dataset:
output_nc
:
3
serial_batches
:
True
pool_size
:
50
transform
:
load_size
:
256
crop_size
:
256
preprocess
:
resize_and_crop
no_flip
:
True
normalize
:
mean
:
(127.5, 127.5, 127.5)
std
:
(127.5, 127.5, 127.5)
transforms
:
-
name
:
Resize
size
:
[
256
,
256
]
interpolation
:
2
#cv2.INTER_CUBIC
keys
:
[
image
,
image
]
-
name
:
Permute
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
127.5
,
127.5
,
127.5
]
std
:
[
127.5
,
127.5
,
127.5
]
keys
:
[
image
,
image
]
optimizer
:
name
:
Adam
...
...
ppgan/datasets/paired_dataset.py
浏览文件 @
9ce257c6
...
...
@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_params, get_transform
from
.image_folder
import
make_dataset
from
.builder
import
DATASETS
from
.transforms.builder
import
build_transforms
@
DATASETS
.
register
()
class
PairedDataset
(
BaseDataset
):
"""A dataset class for paired image dataset.
"""
def
__init__
(
self
,
cfg
):
"""Initialize this dataset class.
...
...
@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset):
cfg (dict) -- stores all the experiment flags
"""
BaseDataset
.
__init__
(
self
,
cfg
)
self
.
dir_AB
=
os
.
path
.
join
(
cfg
.
dataroot
,
cfg
.
phase
)
# get the image directory
self
.
AB_paths
=
sorted
(
make_dataset
(
self
.
dir_AB
,
cfg
.
max_dataset_size
))
# get image paths
assert
(
self
.
cfg
.
transform
.
load_size
>=
self
.
cfg
.
transform
.
crop_size
)
# crop_size should be smaller than the size of loaded image
self
.
dir_AB
=
os
.
path
.
join
(
cfg
.
dataroot
,
cfg
.
phase
)
# get the image directory
self
.
AB_paths
=
sorted
(
make_dataset
(
self
.
dir_AB
,
cfg
.
max_dataset_size
))
# get image paths
self
.
input_nc
=
self
.
cfg
.
output_nc
if
self
.
cfg
.
direction
==
'BtoA'
else
self
.
cfg
.
input_nc
self
.
output_nc
=
self
.
cfg
.
input_nc
if
self
.
cfg
.
direction
==
'BtoA'
else
self
.
cfg
.
output_nc
self
.
transforms
=
build_transforms
(
cfg
.
transforms
)
def
__getitem__
(
self
,
index
):
"""Return a data point and its metadata information.
...
...
@@ -49,27 +52,11 @@ class PairedDataset(BaseDataset):
A
=
AB
[:
h
,
:
w2
,
:]
B
=
AB
[:
h
,
w2
:,
:]
# apply the same transform to both A and B
# transform_params = get_params(self.opt, A.size)
transform_params
=
get_params
(
self
.
cfg
.
transform
,
(
w2
,
h
))
A_transform
=
get_transform
(
self
.
cfg
.
transform
,
transform_params
,
grayscale
=
(
self
.
input_nc
==
1
))
B_transform
=
get_transform
(
self
.
cfg
.
transform
,
transform_params
,
grayscale
=
(
self
.
output_nc
==
1
))
A
=
A_transform
(
A
)
B
=
B_transform
(
B
)
A
,
B
=
self
.
transforms
((
A
,
B
))
return
{
'A'
:
A
,
'B'
:
B
,
'A_paths'
:
AB_path
,
'B_paths'
:
AB_path
}
def
__len__
(
self
):
"""Return the total number of images in the dataset."""
return
len
(
self
.
AB_paths
)
def
get_path_by_indexs
(
self
,
indexs
):
if
isinstance
(
indexs
,
paddle
.
Variable
):
indexs
=
indexs
.
numpy
()
current_paths
=
[]
for
index
in
indexs
:
current_paths
.
append
(
self
.
AB_paths
[
index
])
return
current_paths
ppgan/datasets/transforms/__init__.py
0 → 100644
浏览文件 @
9ce257c6
from
.transforms
import
RandomCrop
,
Resize
,
RandomHorizontalFlip
,
PairedRandomCrop
,
PairedRandomHorizontalFlip
,
Normalize
,
Permute
ppgan/datasets/transforms/builder.py
0 → 100644
浏览文件 @
9ce257c6
import
copy
import
traceback
import
paddle
from
...utils.registry
import
Registry
TRANSFORMS
=
Registry
(
"TRANSFORMS"
)
class
Compose
(
object
):
"""
Composes several transforms together use for composing list of transforms
together for a dataset transform.
Args:
transforms (list): List of transforms to compose.
Returns:
A compose object which is callable, __call__ for this Compose
object will call each given :attr:`transforms` sequencely.
"""
def
__init__
(
self
,
transforms
):
self
.
transforms
=
transforms
def
__call__
(
self
,
data
):
for
f
in
self
.
transforms
:
try
:
data
=
f
(
data
)
except
Exception
as
e
:
stack_info
=
traceback
.
format_exc
()
print
(
"fail to perform transform [{}] with error: "
"{} and stack:
\n
{}"
.
format
(
f
,
e
,
str
(
stack_info
)))
raise
e
return
data
def
build_transforms
(
cfg
):
transforms
=
[]
for
trans_cfg
in
cfg
:
temp_trans_cfg
=
copy
.
deepcopy
(
trans_cfg
)
name
=
temp_trans_cfg
.
pop
(
'name'
)
transforms
.
append
(
TRANSFORMS
.
get
(
name
)(
**
temp_trans_cfg
))
transforms
=
Compose
(
transforms
)
return
transforms
ppgan/datasets/transforms/transforms.py
浏览文件 @
9ce257c6
import
sys
import
random
import
numbers
import
collections
import
numpy
as
np
from
paddle.utils
import
try_import
import
paddle.vision.transforms.functional
as
F
class
RandomCrop
(
object
):
from
.builder
import
TRANSFORMS
def
__init__
(
self
,
output_size
):
if
sys
.
version_info
<
(
3
,
3
):
Sequence
=
collections
.
Sequence
Iterable
=
collections
.
Iterable
else
:
Sequence
=
collections
.
abc
.
Sequence
Iterable
=
collections
.
abc
.
Iterable
class
Transform
():
def
_set_attributes
(
self
,
args
):
"""
Set attributes from the input list of parameters.
Args:
args (list): list of parameters.
"""
if
args
:
for
k
,
v
in
args
.
items
():
if
k
!=
"self"
and
not
k
.
startswith
(
"_"
):
setattr
(
self
,
k
,
v
)
def
apply_image
(
self
,
input
):
raise
NotImplementedError
def
__call__
(
self
,
inputs
):
if
isinstance
(
inputs
,
tuple
):
inputs
=
list
(
inputs
)
if
self
.
keys
is
not
None
:
for
i
,
key
in
enumerate
(
self
.
keys
):
if
isinstance
(
inputs
,
dict
):
inputs
[
key
]
=
getattr
(
self
,
'apply_'
+
key
)(
inputs
[
key
])
elif
isinstance
(
inputs
,
(
list
,
tuple
)):
inputs
[
i
]
=
getattr
(
self
,
'apply_'
+
key
)(
inputs
[
i
])
else
:
inputs
=
self
.
apply_image
(
inputs
)
if
isinstance
(
inputs
,
list
):
inputs
=
tuple
(
inputs
)
return
inputs
@
TRANSFORMS
.
register
()
class
Resize
(
Transform
):
"""Resize the input Image to the given size.
Args:
size (int|list|tuple): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Interpolation mode of resize. Default: 1.
0 : cv2.INTER_NEAREST
1 : cv2.INTER_LINEAR
2 : cv2.INTER_CUBIC
3 : cv2.INTER_AREA
4 : cv2.INTER_LANCZOS4
5 : cv2.INTER_LINEAR_EXACT
7 : cv2.INTER_MAX
8 : cv2.WARP_FILL_OUTLIERS
16: cv2.WARP_INVERSE_MAP
"""
def
__init__
(
self
,
size
,
interpolation
=
1
,
keys
=
None
):
super
().
__init__
()
assert
isinstance
(
size
,
int
)
or
(
isinstance
(
size
,
Iterable
)
and
len
(
size
)
==
2
)
self
.
_set_attributes
(
locals
())
if
isinstance
(
self
.
size
,
Iterable
):
self
.
size
=
tuple
(
size
)
def
apply_image
(
self
,
img
):
return
F
.
resize
(
img
,
self
.
size
,
self
.
interpolation
)
@
TRANSFORMS
.
register
()
class
RandomCrop
(
Transform
):
def
__init__
(
self
,
output_size
,
keys
=
None
):
super
().
__init__
()
self
.
_set_attributes
(
locals
())
if
isinstance
(
output_size
,
int
):
self
.
output_size
=
(
output_size
,
output_size
)
else
:
...
...
@@ -19,12 +105,162 @@ class RandomCrop(object):
j
=
random
.
randint
(
0
,
w
-
tw
)
return
i
,
j
,
th
,
tw
def
__call__
(
self
,
img
):
def
apply_image
(
self
,
img
):
i
,
j
,
h
,
w
=
self
.
_get_params
(
img
)
cropped_img
=
img
[
i
:
i
+
h
,
j
:
j
+
w
]
return
cropped_img
@
TRANSFORMS
.
register
()
class
PairedRandomCrop
(
RandomCrop
):
def
__init__
(
self
,
output_size
,
keys
=
None
):
super
().
__init__
(
output_size
,
keys
)
if
isinstance
(
output_size
,
int
):
self
.
output_size
=
(
output_size
,
output_size
)
else
:
self
.
output_size
=
output_size
def
apply_image
(
self
,
img
,
crop_prams
=
None
):
if
crop_prams
is
not
None
:
i
,
j
,
h
,
w
=
crop_prams
else
:
i
,
j
,
h
,
w
=
self
.
_get_params
(
img
)
cropped_img
=
img
[
i
:
i
+
h
,
j
:
j
+
w
]
return
cropped_img
def
__call__
(
self
,
inputs
):
if
isinstance
(
inputs
,
tuple
):
inputs
=
list
(
inputs
)
if
self
.
keys
is
not
None
:
if
isinstance
(
inputs
,
dict
):
crop_params
=
self
.
_get_params
(
inputs
[
self
.
keys
[
0
]])
elif
isinstance
(
inputs
,
(
list
,
tuple
)):
crop_params
=
self
.
_get_params
(
inputs
[
0
])
for
i
,
key
in
enumerate
(
self
.
keys
):
if
isinstance
(
inputs
,
dict
):
inputs
[
key
]
=
getattr
(
self
,
'apply_'
+
key
)(
inputs
[
key
],
crop_params
)
elif
isinstance
(
inputs
,
(
list
,
tuple
)):
inputs
[
i
]
=
getattr
(
self
,
'apply_'
+
key
)(
inputs
[
i
],
crop_params
)
else
:
crop_params
=
self
.
_get_params
(
inputs
)
inputs
=
self
.
apply_image
(
inputs
,
crop_params
)
if
isinstance
(
inputs
,
list
):
inputs
=
tuple
(
inputs
)
return
inputs
@
TRANSFORMS
.
register
()
class
RandomHorizontalFlip
(
Transform
):
"""Horizontally flip the input data randomly with a given probability.
Args:
prob (float): Probability of the input data being flipped. Default: 0.5
"""
def
__init__
(
self
,
prob
=
0.5
,
keys
=
None
):
super
().
__init__
()
self
.
_set_attributes
(
locals
())
def
apply_image
(
self
,
img
):
if
np
.
random
.
random
()
<
self
.
prob
:
return
F
.
flip
(
img
,
code
=
1
)
return
img
@
TRANSFORMS
.
register
()
class
PairedRandomHorizontalFlip
(
RandomHorizontalFlip
):
def
__init__
(
self
,
prob
=
0.5
,
keys
=
None
):
super
().
__init__
()
self
.
_set_attributes
(
locals
())
def
apply_image
(
self
,
img
,
flip
):
if
flip
:
return
F
.
flip
(
img
,
code
=
1
)
return
img
def
__call__
(
self
,
inputs
):
if
isinstance
(
inputs
,
tuple
):
inputs
=
list
(
inputs
)
flip
=
np
.
random
.
random
()
<
self
.
prob
if
self
.
keys
is
not
None
:
for
i
,
key
in
enumerate
(
self
.
keys
):
if
isinstance
(
inputs
,
dict
):
inputs
[
key
]
=
getattr
(
self
,
'apply_'
+
key
)(
inputs
[
key
],
flip
)
elif
isinstance
(
inputs
,
(
list
,
tuple
)):
inputs
[
i
]
=
getattr
(
self
,
'apply_'
+
key
)(
inputs
[
i
],
flip
)
else
:
inputs
=
self
.
apply_image
(
inputs
,
flip
)
if
isinstance
(
inputs
,
list
):
inputs
=
tuple
(
inputs
)
return
inputs
@
TRANSFORMS
.
register
()
class
Normalize
(
Transform
):
"""Normalize the input data with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
"""
def
__init__
(
self
,
mean
=
0.0
,
std
=
1.0
,
keys
=
None
):
super
().
__init__
()
self
.
_set_attributes
(
locals
())
if
isinstance
(
mean
,
numbers
.
Number
):
mean
=
[
mean
,
mean
,
mean
]
if
isinstance
(
std
,
numbers
.
Number
):
std
=
[
std
,
std
,
std
]
self
.
mean
=
np
.
array
(
mean
,
dtype
=
np
.
float32
).
reshape
(
len
(
mean
),
1
,
1
)
self
.
std
=
np
.
array
(
std
,
dtype
=
np
.
float32
).
reshape
(
len
(
std
),
1
,
1
)
def
apply_image
(
self
,
img
):
return
(
img
-
self
.
mean
)
/
self
.
std
@
TRANSFORMS
.
register
()
class
Permute
(
Transform
):
"""Change input data to a target mode.
For example, most transforms use HWC mode image,
while the Neural Network might use CHW mode input tensor.
Input image should be HWC mode and an instance of numpy.ndarray.
Args:
mode (str): Output mode of input. Default: "CHW".
to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True.
"""
def
__init__
(
self
,
mode
=
"CHW"
,
to_rgb
=
True
,
keys
=
None
):
super
().
__init__
()
self
.
_set_attributes
(
locals
())
assert
mode
in
[
"CHW"
],
"Only support 'CHW' mode, but received mode: {}"
.
format
(
mode
)
self
.
mode
=
mode
self
.
to_rgb
=
to_rgb
def
apply_image
(
self
,
img
):
if
self
.
to_rgb
:
img
=
img
[...,
::
-
1
]
if
self
.
mode
==
"CHW"
:
return
img
.
transpose
((
2
,
0
,
1
))
return
img
class
Crop
():
def
__init__
(
self
,
pos
,
size
):
self
.
pos
=
pos
...
...
@@ -35,6 +271,6 @@ class Crop():
x
,
y
=
self
.
pos
th
=
tw
=
self
.
size
if
(
ow
>
tw
or
oh
>
th
):
return
img
[
y
:
y
+
th
,
x
:
x
+
tw
]
return
img
[
y
:
y
+
th
,
x
:
x
+
tw
]
return
img
\ No newline at end of file
return
img
ppgan/datasets/unpaired_dataset.py
浏览文件 @
9ce257c6
...
...
@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_transform
from
.image_folder
import
make_dataset
from
.builder
import
DATASETS
from
.transforms.builder
import
build_transforms
@
DATASETS
.
register
()
class
UnpairedDataset
(
BaseDataset
):
"""
"""
def
__init__
(
self
,
cfg
):
"""Initialize this dataset class.
...
...
@@ -19,18 +19,25 @@ class UnpairedDataset(BaseDataset):
cfg (dict) -- stores all the experiment flags
"""
BaseDataset
.
__init__
(
self
,
cfg
)
self
.
dir_A
=
os
.
path
.
join
(
cfg
.
dataroot
,
cfg
.
phase
+
'A'
)
# create a path '/path/to/data/trainA'
self
.
dir_B
=
os
.
path
.
join
(
cfg
.
dataroot
,
cfg
.
phase
+
'B'
)
# create a path '/path/to/data/trainB'
self
.
dir_A
=
os
.
path
.
join
(
cfg
.
dataroot
,
cfg
.
phase
+
'A'
)
# create a path '/path/to/data/trainA'
self
.
dir_B
=
os
.
path
.
join
(
cfg
.
dataroot
,
cfg
.
phase
+
'B'
)
# create a path '/path/to/data/trainB'
self
.
A_paths
=
sorted
(
make_dataset
(
self
.
dir_A
,
cfg
.
max_dataset_size
))
# load images from '/path/to/data/trainA'
self
.
B_paths
=
sorted
(
make_dataset
(
self
.
dir_B
,
cfg
.
max_dataset_size
))
# load images from '/path/to/data/trainB'
self
.
A_paths
=
sorted
(
make_dataset
(
self
.
dir_A
,
cfg
.
max_dataset_size
))
# load images from '/path/to/data/trainA'
self
.
B_paths
=
sorted
(
make_dataset
(
self
.
dir_B
,
cfg
.
max_dataset_size
))
# load images from '/path/to/data/trainB'
self
.
A_size
=
len
(
self
.
A_paths
)
# get the size of dataset A
self
.
B_size
=
len
(
self
.
B_paths
)
# get the size of dataset B
btoA
=
self
.
cfg
.
direction
==
'BtoA'
input_nc
=
self
.
cfg
.
output_nc
if
btoA
else
self
.
cfg
.
input_nc
# get the number of channels of input image
output_nc
=
self
.
cfg
.
input_nc
if
btoA
else
self
.
cfg
.
output_nc
# get the number of channels of output image
self
.
transform_A
=
get_transform
(
self
.
cfg
.
transform
,
grayscale
=
(
input_nc
==
1
))
self
.
transform_B
=
get_transform
(
self
.
cfg
.
transform
,
grayscale
=
(
output_nc
==
1
))
input_nc
=
self
.
cfg
.
output_nc
if
btoA
else
self
.
cfg
.
input_nc
# get the number of channels of input image
output_nc
=
self
.
cfg
.
input_nc
if
btoA
else
self
.
cfg
.
output_nc
# get the number of channels of output image
self
.
transform_A
=
build_transforms
(
self
.
cfg
.
transforms
)
self
.
transform_B
=
build_transforms
(
self
.
cfg
.
transforms
)
self
.
reset_paths
()
...
...
@@ -49,10 +56,11 @@ class UnpairedDataset(BaseDataset):
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path
=
self
.
A_paths
[
index
%
self
.
A_size
]
# make sure index is within then range
if
self
.
cfg
.
serial_batches
:
# make sure index is within then range
A_path
=
self
.
A_paths
[
index
%
self
.
A_size
]
# make sure index is within then range
if
self
.
cfg
.
serial_batches
:
# make sure index is within then range
index_B
=
index
%
self
.
B_size
else
:
# randomize the index for domain B to avoid fixed pairs.
else
:
# randomize the index for domain B to avoid fixed pairs.
index_B
=
random
.
randint
(
0
,
self
.
B_size
-
1
)
B_path
=
self
.
B_paths
[
index_B
]
...
...
ppgan/models/builder.py
浏览文件 @
9ce257c6
...
...
@@ -2,18 +2,9 @@ import paddle
from
..utils.registry
import
Registry
MODELS
=
Registry
(
"MODEL"
)
def
build_model
(
cfg
):
# dataset = MODELS.get(cfg.MODEL.name)(cfg.MODEL)
# place = paddle.CUDAPlace(0)
# dataloader = paddle.io.DataLoader(dataset,
# batch_size=1, #opt.batch_size,
# places=place,
# shuffle=True, #not opt.serial_batches,
# num_workers=0)#int(opt.num_threads))
model
=
MODELS
.
get
(
cfg
.
model
.
name
)(
cfg
)
return
model
# pass
\ No newline at end of file
ppgan/utils/animate.py
浏览文件 @
9ce257c6
import
os
from
tqdm
import
tqdm
import
numpy
as
np
from
scipy.spatial
import
ConvexHull
import
paddle
import
imageio
from
scipy.spatial
import
ConvexHull
import
numpy
as
np
def
normalize_kp
(
kp_source
,
kp_driving
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录