PaddlePaddle / PaddleGAN
Commit 06d3c848
Commit 06d3c848
Authored Sep 29, 2020 by LielinJiang

reproduce transforms

Parent: 46376659
12 changed files with 464 additions and 146 deletions (+464 −146)
configs/cyclegan_cityscapes.yaml          +18 −21
configs/cyclegan_horse2zebra.yaml         +18 −19
configs/pix2pix_cityscapes.yaml           +27 −20
configs/pix2pix_cityscapes_2gpus.yaml     +26 −20
configs/pix2pix_facades.yaml              +26 −20
ppgan/datasets/paired_dataset.py          +14 −18
ppgan/datasets/transforms/__init__.py     +1 −0
ppgan/datasets/transforms/builder.py      +55 −0
ppgan/datasets/transforms/transforms.py   +256 −5
ppgan/datasets/unpaired_dataset.py        +21 −12
ppgan/models/builder.py                   +0 −9
ppgan/models/pix2pix_model.py             +2 −2
configs/cyclegan_cityscapes.yaml

@@ -36,16 +36,17 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [286, 286]
+        - name: RandomCrop
+          output_size: [256, 256]
+        - name: RandomHorizontalFlip
+          prob: 0.5
+        - name: Permute
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
   test:
     name: SingleDataset
     dataroot: data/cityscapes/testB
@@ -55,17 +56,13 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [256, 256]
+        - name: Permute
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
 optimizer:
     name: Adam
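Each entry in the new `transforms:` list is a plain name/argument mapping: `name` selects a transform class from a registry, and the remaining keys become constructor keyword arguments. The following is a minimal illustrative sketch of that mapping only; the class stubs and the simple registry dict here are hypothetical stand-ins, not the PaddleGAN implementation (which appears in ppgan/datasets/transforms/builder.py below).

    # Sketch: turn a `transforms:`-style config list into transform objects.
    transforms_cfg = [
        {"name": "Resize", "size": [286, 286]},
        {"name": "RandomCrop", "output_size": [256, 256]},
        {"name": "RandomHorizontalFlip", "prob": 0.5},
    ]

    class Resize:
        def __init__(self, size, interpolation=1):
            self.size = tuple(size)

    class RandomCrop:
        def __init__(self, output_size):
            self.output_size = tuple(output_size)

    class RandomHorizontalFlip:
        def __init__(self, prob=0.5):
            self.prob = prob

    registry = {"Resize": Resize, "RandomCrop": RandomCrop,
                "RandomHorizontalFlip": RandomHorizontalFlip}

    # `name` picks the class, the rest of the dict becomes kwargs.
    pipeline = [registry[c["name"]](**{k: v for k, v in c.items() if k != "name"})
                for c in transforms_cfg]
    print([type(t).__name__ for t in pipeline])  # ['Resize', 'RandomCrop', 'RandomHorizontalFlip']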
configs/cyclegan_horse2zebra.yaml

@@ -35,16 +35,17 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 50
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [286, 286]
+        - name: RandomCrop
+          output_size: [256, 256]
+        - name: RandomHorizontalFlip
+          prob: 0.5
+        - name: Permute
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
   test:
     name: SingleDataset
     dataroot: data/horse2zebra/testA
@@ -55,15 +56,13 @@ dataset:
     serial_batches: False
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transform:
+        - name: Resize
+          size: [256, 256]
+        - name: Permute
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
 optimizer:
     name: Adam
configs/pix2pix_cityscapes.yaml

@@ -33,16 +33,22 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 0
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [286, 286]
+          keys: [image, image]
+        - name: PairedRandomCrop
+          output_size: [256, 256]
+          keys: [image, image]
+        - name: PairedRandomHorizontalFlip
+          prob: 0.5
+          keys: [image, image]
+        - name: Permute
+          keys: [image, image]
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
+          keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/cityscapes/
@@ -53,16 +59,17 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [256, 256]
+          keys: [image, image]
+        - name: Permute
+          keys: [image, image]
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
+          keys: [image, image]
 optimizer:
     name: Adam
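The `keys: [image, image]` fields tell each transform that both elements of the (A, B) pair are images, and the Paired* variants additionally reuse one set of random parameters for both elements. A minimal standalone sketch of that idea follows; the class below is a hypothetical simplification for illustration, not the registered PairedRandomCrop from this commit.

    import numpy as np

    # Sketch: draw the crop window once, apply it to every element of the pair,
    # mimicking what `keys: [image, image]` plus PairedRandomCrop achieve.
    class PairedRandomCropSketch:
        def __init__(self, output_size):
            self.th, self.tw = output_size

        def __call__(self, inputs):
            h, w = inputs[0].shape[:2]
            i = np.random.randint(0, h - self.th + 1)
            j = np.random.randint(0, w - self.tw + 1)
            return tuple(img[i:i + self.th, j:j + self.tw] for img in inputs)

    A = np.zeros((286, 286, 3), dtype=np.uint8)
    B = np.ones((286, 286, 3), dtype=np.uint8)
    A_c, B_c = PairedRandomCropSketch((256, 256))((A, B))
    assert A_c.shape == B_c.shape == (256, 256, 3)  # same window for both halves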
configs/pix2pix_cityscapes_2gpus.yaml

@@ -32,16 +32,22 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 0
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [286, 286]
+          keys: [image, image]
+        - name: PairedRandomCrop
+          output_size: [256, 256]
+          keys: [image, image]
+        - name: PairedRandomHorizontalFlip
+          prob: 0.5
+          keys: [image, image]
+        - name: Permute
+          keys: [image, image]
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
+          keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/cityscapes/
@@ -52,16 +58,16 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [256, 256]
+          keys: [image, image]
+        - name: Permute
+          keys: [image, image]
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
+          keys: [image, image]
 optimizer:
     name: Adam
configs/pix2pix_facades.yaml

@@ -32,16 +32,22 @@ dataset:
     output_nc: 3
     serial_batches: False
     pool_size: 0
-    transform:
-        load_size: 286
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: False
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [286, 286]
+          keys: [image, image]
+        - name: PairedRandomCrop
+          output_size: [256, 256]
+          keys: [image, image]
+        - name: PairedRandomHorizontalFlip
+          prob: 0.5
+          keys: [image, image]
+        - name: Permute
+          keys: [image, image]
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
+          keys: [image, image]
   test:
     name: PairedDataset
     dataroot: data/facades/
@@ -52,16 +58,16 @@ dataset:
     output_nc: 3
     serial_batches: True
     pool_size: 50
-    transform:
-        load_size: 256
-        crop_size: 256
-        preprocess: resize_and_crop
-        no_flip: True
-    normalize:
-        mean: (127.5, 127.5, 127.5)
-        std: (127.5, 127.5, 127.5)
+    transforms:
+        - name: Resize
+          size: [256, 256]
+          keys: [image, image]
+        - name: Permute
+          keys: [image, image]
+        - name: Normalize
+          mean: [127.5, 127.5, 127.5]
+          std: [127.5, 127.5, 127.5]
+          keys: [image, image]
 optimizer:
     name: Adam
ppgan/datasets/paired_dataset.py (after this commit)

@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_params, get_transform
from .image_folder import make_dataset
from .builder import DATASETS
from .transforms.builder import build_transforms


@DATASETS.register()
class PairedDataset(BaseDataset):
    """A dataset class for paired image dataset.
    """
    def __init__(self, cfg):
        """Initialize this dataset class.

@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset):
            cfg (dict) -- stores all the experiment flags
        """
        BaseDataset.__init__(self, cfg)
        self.dir_AB = os.path.join(cfg.dataroot, cfg.phase)  # get the image directory
        self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size))  # get image paths
        # assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size) # crop_size should be smaller than the size of loaded image
        self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
        self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
        self.transforms = build_transforms(cfg.transforms)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

@@ -49,27 +52,20 @@ class PairedDataset(BaseDataset):
        A = AB[:h, :w2, :]
        B = AB[:h, w2:, :]

        # apply the same transform to both A and B
        # transform_params = get_params(self.cfg.transform, (w2, h))
        # A_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.input_nc == 1))
        # B_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.output_nc == 1))
        # A = A_transform(A)
        # B = B_transform(B)
        A, B = self.transforms((A, B))

        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.AB_paths)

    def get_path_by_indexs(self, indexs):
        if isinstance(indexs, paddle.Variable):
            indexs = indexs.numpy()
        current_paths = []
        for index in indexs:
            current_paths.append(self.AB_paths[index])
        return current_paths
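For context on the `A`/`B` slices above: PairedDataset stores each paired sample as one side-by-side image and splits it down the middle before handing both halves to the shared transform pipeline. A small sketch of that split, using a zero array as a hypothetical stand-in for the loaded AB image:

    import numpy as np

    # Sketch of the split performed in PairedDataset.__getitem__:
    # A is the left half, B is the right half of the concatenated image.
    AB = np.zeros((256, 512, 3), dtype=np.uint8)   # stand-in for the loaded AB image
    h, w = AB.shape[:2]
    w2 = w // 2
    A = AB[:h, :w2, :]   # left half
    B = AB[:h, w2:, :]   # right half
    assert A.shape == B.shape == (256, 256, 3)
    # Both halves then go through the same pipeline: A, B = self.transforms((A, B))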
ppgan/datasets/transforms/__init__.py (new file, 0 → 100644)

from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute
ppgan/datasets/transforms/builder.py (new file, 0 → 100644)

import copy
import traceback

import paddle

from ...utils.registry import Registry

TRANSFORMS = Registry("TRANSFORMS")


class Compose(object):
    """
    Composes several transforms together use for composing list of transforms
    together for a dataset transform.

    Args:
        transforms (list): List of transforms to compose.

    Returns:
        A compose object which is callable, __call__ for this Compose
        object will call each given :attr:`transforms` sequencely.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        for f in self.transforms:
            try:
                # multi-fileds in a sample
                # if isinstance(data, Sequence):
                #     data = f(*data)
                # # single field in a sample, call transform directly
                # else:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                print("fail to perform transform [{}] with error: "
                      "{} and stack:\n{}".format(f, e, str(stack_info)))
                raise e
        return data


def build_transform(cfg):
    pass


def build_transforms(cfg):
    transforms = []

    for trans_cfg in cfg:
        temp_trans_cfg = copy.deepcopy(trans_cfg)
        name = temp_trans_cfg.pop('name')
        transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))

    transforms = Compose(transforms)

    return transforms
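A hypothetical end-to-end usage sketch of the new builder, with a config list mirroring the cyclegan_*.yaml files above. It assumes PaddleGAN at this commit is importable and that the installed Paddle version's F.resize/F.flip accept numpy arrays (the registered transforms rely on that); exact behaviour depends on the Paddle release.

    import numpy as np
    from ppgan.datasets.transforms.builder import build_transforms

    # Each dict mirrors a YAML entry: `name` picks a registered transform,
    # the remaining keys are passed to its constructor.
    transforms_cfg = [
        dict(name='Resize', size=[286, 286]),
        dict(name='RandomCrop', output_size=[256, 256]),
        dict(name='RandomHorizontalFlip', prob=0.5),
        dict(name='Permute'),
        dict(name='Normalize', mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
    ]

    pipeline = build_transforms(transforms_cfg)      # a Compose over five transforms
    img = np.random.randint(0, 256, (300, 300, 3)).astype('float32')
    out = pipeline(img)                              # expected: CHW array, roughly in [-1, 1]
    print(out.shape)                                 # expected: (3, 256, 256)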
ppgan/datasets/transforms/transforms.py (after this commit)

import sys
import types
import random
import numbers
import warnings
import traceback
import collections

import numpy as np

from paddle.utils import try_import
import paddle.vision.transforms.functional as F
import paddle.vision.transforms.transforms as T

from .builder import TRANSFORMS

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable


class Transform():
    def _set_attributes(self, args):
        """
        Set attributes from the input list of parameters.

        Args:
            args (list): list of parameters.
        """
        if args:
            for k, v in args.items():
                if k != "self" and not k.startswith("_"):
                    setattr(self, k, v)

    def apply_image(self, input):
        raise NotImplementedError

    def __call__(self, inputs):
        if isinstance(inputs, tuple):
            inputs = list(inputs)

        if self.keys is not None:
            for i, key in enumerate(self.keys):
                if isinstance(inputs, dict):
                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key])
                elif isinstance(inputs, (list, tuple)):
                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
        else:
            inputs = self.apply_image(inputs)

        if isinstance(inputs, list):
            inputs = tuple(inputs)

        return inputs


@TRANSFORMS.register()
class Resize(Transform):
    """Resize the input Image to the given size.

    Args:
        size (int|list|tuple): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Interpolation mode of resize. Default: 1.
            0 : cv2.INTER_NEAREST
            1 : cv2.INTER_LINEAR
            2 : cv2.INTER_CUBIC
            3 : cv2.INTER_AREA
            4 : cv2.INTER_LANCZOS4
            5 : cv2.INTER_LINEAR_EXACT
            7 : cv2.INTER_MAX
            8 : cv2.WARP_FILL_OUTLIERS
            16: cv2.WARP_INVERSE_MAP
    """
    def __init__(self, size, interpolation=1, keys=None):
        super().__init__()
        assert isinstance(size, int) or (isinstance(size, Iterable)
                                         and len(size) == 2)
        self._set_attributes(locals())
        if isinstance(self.size, Iterable):
            self.size = tuple(size)

    def apply_image(self, img):
        return F.resize(img, self.size, self.interpolation)


@TRANSFORMS.register()
class RandomCrop(Transform):
    def __init__(self, output_size, keys=None):
        super().__init__()
        self._set_attributes(locals())
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
...
@@ -19,12 +111,171 @@ class RandomCrop(object):
            j = random.randint(0, w - tw)
        return i, j, th, tw

    def apply_image(self, img):
        i, j, h, w = self._get_params(img)
        cropped_img = img[i:i + h, j:j + w]
        return cropped_img


@TRANSFORMS.register()
class PairedRandomCrop(RandomCrop):
    def __init__(self, output_size, keys=None):
        super().__init__(output_size, keys)

        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    def apply_image(self, img, crop_prams=None):
        if crop_prams is not None:
            i, j, h, w = crop_prams
        else:
            i, j, h, w = self._get_params(img)
        cropped_img = img[i:i + h, j:j + w]
        return cropped_img

    def __call__(self, inputs):
        if isinstance(inputs, tuple):
            inputs = list(inputs)

        if self.keys is not None:
            if isinstance(inputs, dict):
                crop_params = self._get_params(inputs[self.keys[0]])
            elif isinstance(inputs, (list, tuple)):
                crop_params = self._get_params(inputs[0])

            for i, key in enumerate(self.keys):
                if isinstance(inputs, dict):
                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key], crop_params)
                elif isinstance(inputs, (list, tuple)):
                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i], crop_params)
        else:
            crop_params = self._get_params(inputs)
            inputs = self.apply_image(inputs, crop_params)

        if isinstance(inputs, list):
            inputs = tuple(inputs)

        return inputs


@TRANSFORMS.register()
class RandomHorizontalFlip(Transform):
    """Horizontally flip the input data randomly with a given probability.

    Args:
        prob (float): Probability of the input data being flipped. Default: 0.5
    """
    def __init__(self, prob=0.5, keys=None):
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img):
        if np.random.random() < self.prob:
            return F.flip(img, code=1)
        return img


# import paddle
# paddle.vision.transforms.RandomHorizontalFlip


@TRANSFORMS.register()
class PairedRandomHorizontalFlip(RandomHorizontalFlip):
    def __init__(self, prob=0.5, keys=None):
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img, flip):
        if flip:
            return F.flip(img, code=1)
        return img

    def __call__(self, inputs):
        if isinstance(inputs, tuple):
            inputs = list(inputs)

        flip = np.random.random() < self.prob
        if self.keys is not None:
            for i, key in enumerate(self.keys):
                if isinstance(inputs, dict):
                    inputs[key] = getattr(self, 'apply_' + key)(inputs[key], flip)
                elif isinstance(inputs, (list, tuple)):
                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i], flip)
        else:
            inputs = self.apply_image(inputs, flip)

        if isinstance(inputs, list):
            inputs = tuple(inputs)

        return inputs


@TRANSFORMS.register()
class Normalize(Transform):
    """Normalize the input data with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
    this transform will normalize each channel of the input data.
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (int|float|list): Sequence of means for each channel.
        std (int|float|list): Sequence of standard deviations for each channel.
    """
    def __init__(self, mean=0.0, std=1.0, keys=None):
        super().__init__()
        self._set_attributes(locals())

        if isinstance(mean, numbers.Number):
            mean = [mean, mean, mean]

        if isinstance(std, numbers.Number):
            std = [std, std, std]

        self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
        self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)

    def apply_image(self, img):
        return (img - self.mean) / self.std


@TRANSFORMS.register()
class Permute(Transform):
    """Change input data to a target mode.
    For example, most transforms use HWC mode image,
    while the Neural Network might use CHW mode input tensor.
    Input image should be HWC mode and an instance of numpy.ndarray.

    Args:
        mode (str): Output mode of input. Default: "CHW".
        to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True.
    """
    def __init__(self, mode="CHW", to_rgb=True, keys=None):
        super().__init__()
        self._set_attributes(locals())
        assert mode in ["CHW"], "Only support 'CHW' mode, but received mode: {}".format(mode)
        self.mode = mode
        self.to_rgb = to_rgb

    def apply_image(self, img):
        if self.to_rgb:
            img = img[..., ::-1]
        if self.mode == "CHW":
            return img.transpose((2, 0, 1))
        return img


# import paddle
# paddle.vision.transforms.Normalize
# TRANSFORMS.register(T.Normalize)


class Crop():
    def __init__(self, pos, size):
        self.pos = pos
...
@@ -35,6 +286,6 @@ class Crop():
        x, y = self.pos
        th = tw = self.size
        if (ow > tw or oh > th):
            return img[y: y + th, x: x + tw]

        return img
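The base Transform.__call__ above dispatches on `keys`: for a tuple input with `keys=['image', 'image']`, element i is routed through `apply_<key>`, i.e. `apply_image` here. A minimal standalone sketch of that dispatch follows; TransformSketch is a simplified hypothetical class, not one of the registered PaddleGAN transforms.

    import numpy as np

    # Simplified stand-in for the Transform.__call__ dispatch: each entry in
    # `keys` names the apply_* method used for the element at the same position.
    class TransformSketch:
        def __init__(self, keys=None):
            self.keys = keys

        def apply_image(self, img):
            return img.astype(np.float32) / 255.0   # toy per-element operation

        def __call__(self, inputs):
            if isinstance(inputs, tuple):
                inputs = list(inputs)
            if self.keys is not None:
                for i, key in enumerate(self.keys):
                    inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
            else:
                inputs = self.apply_image(inputs)
            return tuple(inputs) if isinstance(inputs, list) else inputs

    A = np.full((2, 2, 3), 255, dtype=np.uint8)
    B = np.zeros((2, 2, 3), dtype=np.uint8)
    A_t, B_t = TransformSketch(keys=['image', 'image'])((A, B))
    print(A_t.max(), B_t.max())   # 1.0 0.0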
ppgan/datasets/unpaired_dataset.py (after this commit)

@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset
from .builder import DATASETS
from .transforms.builder import build_transforms


@DATASETS.register()
class UnpairedDataset(BaseDataset):
    """
    """
    def __init__(self, cfg):
        """Initialize this dataset class.

@@ -19,18 +19,26 @@ class UnpairedDataset(BaseDataset):
            cfg (dict) -- stores all the experiment flags
        """
        BaseDataset.__init__(self, cfg)
        self.dir_A = os.path.join(cfg.dataroot, cfg.phase + 'A')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(cfg.dataroot, cfg.phase + 'B')  # create a path '/path/to/data/trainB'

        self.A_paths = sorted(make_dataset(self.dir_A, cfg.max_dataset_size))  # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, cfg.max_dataset_size))  # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        btoA = self.cfg.direction == 'BtoA'
        input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc  # get the number of channels of input image
        output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc  # get the number of channels of output image
        # self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1))
        # self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1))
        self.transform_A = build_transforms(self.cfg.transforms)
        self.transform_B = build_transforms(self.cfg.transforms)

        self.reset_paths()

@@ -49,10 +57,11 @@ class UnpairedDataset(BaseDataset):
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        A_path = self.A_paths[index % self.A_size]  # make sure index is within then range
        if self.cfg.serial_batches:  # make sure index is within then range
            index_B = index % self.B_size
        else:  # randomize the index for domain B to avoid fixed pairs.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
ppgan/models/builder.py

@@ -2,18 +2,9 @@ import paddle
 from ..utils.registry import Registry

 MODELS = Registry("MODEL")


 def build_model(cfg):
-    # dataset = MODELS.get(cfg.MODEL.name)(cfg.MODEL)
-    # place = paddle.CUDAPlace(0)
-    # dataloader = paddle.io.DataLoader(dataset,
-    #                                   batch_size=1, #opt.batch_size,
-    #                                   places=place,
-    #                                   shuffle=True, #not opt.serial_batches,
-    #                                   num_workers=0)#int(opt.num_threads))
     model = MODELS.get(cfg.model.name)(cfg)
     return model
-    # pass
\ No newline at end of file
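build_model relies on the same registry pattern as the new transforms builder: it looks up the class registered under cfg.model.name and instantiates it with the whole config. A toy sketch of that pattern follows; the Registry class here is a simplified hypothetical stand-in for ppgan.utils.registry.Registry, and the dict-based cfg is illustrative only.

    # Toy registry sketch illustrating the lookup used by build_model.
    class Registry:
        def __init__(self, name):
            self.name, self._map = name, {}

        def register(self):
            def deco(cls):
                self._map[cls.__name__] = cls   # register under the class name
                return cls
            return deco

        def get(self, name):
            return self._map[name]

    MODELS = Registry("MODEL")

    @MODELS.register()
    class Pix2PixModel:
        def __init__(self, cfg):
            self.cfg = cfg

    def build_model(cfg):
        return MODELS.get(cfg["model"]["name"])(cfg)

    model = build_model({"model": {"name": "Pix2PixModel"}})
    print(type(model).__name__)   # Pix2PixModel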
ppgan/models/pix2pix_model.py

@@ -77,8 +77,8 @@ class Pix2PixModel(BaseModel):
         """
         AtoB = self.opt.dataset.train.direction == 'AtoB'
-        self.real_A = paddle.to_variable(input['A' if AtoB else 'B'])
-        self.real_B = paddle.to_variable(input['B' if AtoB else 'A'])
+        self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
+        self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])

         self.image_paths = input['A_paths' if AtoB else 'B_paths']

     def forward(self):
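paddle.to_variable was the older dygraph entry point; paddle.to_tensor is its replacement in Paddle 2.x and accepts numpy arrays, lists, or scalars. A minimal sketch of the converted call, assuming a Paddle 2.x install; the random batch below is a hypothetical stand-in for a dataloader sample.

    import numpy as np
    import paddle

    # Stand-in batch shaped like a pix2pix sample after Permute/Normalize.
    batch = {'A': np.random.rand(1, 3, 256, 256).astype('float32'),
             'B': np.random.rand(1, 3, 256, 256).astype('float32')}

    real_A = paddle.to_tensor(batch['A'])   # replaces paddle.to_variable(...)
    real_B = paddle.to_tensor(batch['B'])
    print(real_A.shape, real_A.dtype)       # [1, 3, 256, 256] paddle.float32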