PaddlePaddle / PaddleGAN
Commit 5972bf7f (unverified)
Authored Jan 13, 2021 by LielinJiang
Committed by GitHub on Jan 13, 2021
Add RealSR (#143)
* add realsr model
Parent: 89dbb63f
Showing 10 changed files with 1205 additions and 1 deletion
configs/realsr_bicubic_noise_x4_df2k.yaml          +131 -0
configs/realsr_kernel_noise_x4_dped.yaml           +131 -0
data/realsr_preprocess/collect_noise.py            +84 -0
data/realsr_preprocess/create_bicubic_dataset.py   +132 -0
data/realsr_preprocess/create_kernel_dataset.py    +153 -0
data/realsr_preprocess/imresize.py                 +252 -0
data/realsr_preprocess/paths.yml                   +13 -0
data/realsr_preprocess/utils.py                    +276 -0
ppgan/datasets/preprocess/__init__.py              +1 -1
ppgan/datasets/preprocess/transforms.py            +32 -0
configs/realsr_bicubic_noise_x4_df2k.yaml (new file, mode 0 → 100644)

total_iters: 60000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: ESRGAN
  generator:
    name: RRDBNet
    in_nc: 3
    out_nc: 3
    nf: 64
    nb: 23
  discriminator:
    name: VGGDiscriminator128
    in_channels: 3
    num_feat: 64
  pixel_criterion:
    name: L1Loss
    loss_weight: !!float 1e-2
  perceptual_criterion:
    name: PerceptualLoss
    layer_weights:
      '34': 1.0
    perceptual_weight: 1.0
    style_weight: 0.0
    norm_img: False
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla
    loss_weight: !!float 5e-3

dataset:
  train:
    name: SRDataset
    gt_folder: data/realsr_preprocess/DF2K/generated/tdsr/HR_sub/
    lq_folder: data/realsr_preprocess/DF2K/generated/tdsr/LR_sub/
    num_workers: 4
    batch_size: 16
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: SRPairedRandomCrop
            gt_patch_size: 128
            scale: 4
            keys: [image, image]
          - name: PairedRandomHorizontalFlip
            keys: [image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., .0, 0.]
            std: [255., 255., 255.]
            keys: [image, image]
          - name: SRNoise
            noise_path: data/realsr_preprocess/DF2K/Corrupted_noise/
            size: 32
            keys: [image]
  test:
    name: SRDataset
    gt_folder: data/DIV2K/val_set14/Set14
    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., .0, 0.]
            std: [255., 255., 255.]
            keys: [image, image]

lr_scheduler:
  name: MultiStepDecay
  learning_rate: 0.0001
  milestones: [5000, 10000, 20000, 30000]
  gamma: 0.5

optimizer:
  optimG:
    name: Adam
    net_names:
      - generator
    weight_decay: 0.0
    beta1: 0.9
    beta2: 0.999
  optimD:
    name: Adam
    net_names:
      - discriminator
    weight_decay: 0.0
    beta1: 0.9
    beta2: 0.999

validate:
  interval: 5000
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: false
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: false

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5000
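
The `lr_scheduler` block above multiplies the learning rate by `gamma` at each milestone. A minimal sketch of that schedule in plain Python, assuming the usual multi-step rule (rate = base rate × gamma^k, where k is the number of milestones already reached). `lr_at` is a hypothetical helper for illustration, not a PaddleGAN function:

```python
from bisect import bisect_right


def lr_at(iteration,
          base_lr=1e-4,
          milestones=(5000, 10000, 20000, 30000),
          gamma=0.5):
    """Learning rate at `iteration` under a multi-step decay rule (assumed)."""
    # Count the milestones m with m <= iteration.
    passed = bisect_right(milestones, iteration)
    return base_lr * gamma**passed


# Rates at a few points of the 60000-iteration run configured above.
schedule = {it: lr_at(it) for it in (0, 4999, 5000, 10000, 20000, 30000, 59999)}
```

Under this assumption the rate stays at its final value, one sixteenth of the base rate, for the last 30000 iterations.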
configs/realsr_kernel_noise_x4_dped.yaml (new file, mode 0 → 100644)

total_iters: 60000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: ESRGAN
  generator:
    name: RRDBNet
    in_nc: 3
    out_nc: 3
    nf: 64
    nb: 23
  discriminator:
    name: VGGDiscriminator128
    in_channels: 3
    num_feat: 64
  pixel_criterion:
    name: L1Loss
    loss_weight: !!float 1e-2
  perceptual_criterion:
    name: PerceptualLoss
    layer_weights:
      '34': 1.0
    perceptual_weight: 1.0
    style_weight: 0.0
    norm_img: False
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla
    loss_weight: !!float 5e-3

dataset:
  train:
    name: SRDataset
    gt_folder: data/realsr_preprocess/DPED/generated/clean/train_tdsr/HR/
    lq_folder: data/realsr_preprocess/DPED/generated/clean/train_tdsr/LR/
    num_workers: 4
    batch_size: 16
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: SRPairedRandomCrop
            gt_patch_size: 128
            scale: 4
            keys: [image, image]
          - name: PairedRandomHorizontalFlip
            keys: [image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., .0, 0.]
            std: [255., 255., 255.]
            keys: [image, image]
          - name: SRNoise
            noise_path: data/realsr_preprocess/DPED/DPEDiphone_noise/
            size: 32
            keys: [image]
  test:
    name: SRDataset
    gt_folder: data/DIV2K/val_set14/Set14
    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., .0, 0.]
            std: [255., 255., 255.]
            keys: [image, image]

lr_scheduler:
  name: MultiStepDecay
  learning_rate: 0.0001
  milestones: [5000, 10000, 20000, 30000]
  gamma: 0.5

optimizer:
  optimG:
    name: Adam
    net_names:
      - generator
    weight_decay: 0.0
    beta1: 0.9
    beta2: 0.999
  optimD:
    name: Adam
    net_names:
      - discriminator
    weight_decay: 0.0
    beta1: 0.9
    beta2: 0.999

validate:
  interval: 5000
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: false
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: false

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5000
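
The training pipeline above crops a 128×128 ground-truth patch with `scale: 4`, which implies a 32×32 low-quality patch (the `size: 32` given to `SRNoise`). A minimal sketch of how paired crop boxes can be kept aligned, assuming the LQ offset is drawn first and the GT offset is `scale` times it. `paired_crop_boxes` is a hypothetical helper, not PaddleGAN's `SRPairedRandomCrop`:

```python
import random


def paired_crop_boxes(lq_h, lq_w, gt_patch_size=128, scale=4, rng=random):
    """Return aligned (top, left, height, width) boxes for the LQ and GT images."""
    lq_patch = gt_patch_size // scale
    if lq_h < lq_patch or lq_w < lq_patch:
        raise ValueError('LQ image is smaller than the requested patch')

    # Draw the offset in LQ coordinates (randint is inclusive at both ends).
    top = rng.randint(0, lq_h - lq_patch)
    left = rng.randint(0, lq_w - lq_patch)

    lq_box = (top, left, lq_patch, lq_patch)
    # The GT box covers the same scene region, scaled up.
    gt_box = (top * scale, left * scale, gt_patch_size, gt_patch_size)
    return lq_box, gt_box


lq_box, gt_box = paired_crop_boxes(100, 120, rng=random.Random(0))
```

Drawing the offset in LQ coordinates keeps the GT offset a multiple of the scale, so the two patches always show the same region.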
data/realsr_preprocess/collect_noise.py (new file, mode 0 → 100644)

from PIL import Image
import numpy as np
import os.path as osp
import glob
import os
import argparse
import yaml

parser = argparse.ArgumentParser(description='create a dataset')
parser.add_argument('--dataset',
                    default='df2k',
                    type=str,
                    help='selecting different datasets')
parser.add_argument('--artifacts',
                    default='',
                    type=str,
                    help='selecting different artifacts type')
parser.add_argument('--cleanup_factor',
                    default=2,
                    type=int,
                    help='downscaling factor for image cleanup')
parser.add_argument('--upscale_factor',
                    default=4,
                    type=int,
                    choices=[4],
                    help='super resolution upscale factor')
opt = parser.parse_args()

# define input and target directories
with open('./preprocess/paths.yml', 'r') as stream:
    PATHS = yaml.load(stream)


def noise_patch(rgb_img, sp, max_var, min_mean):
    img = rgb_img.convert('L')
    rgb_img = np.array(rgb_img)
    img = np.array(img)

    w, h = img.shape
    collect_patchs = []

    for i in range(0, w - sp, sp):
        for j in range(0, h - sp, sp):
            patch = img[i:i + sp, j:j + sp]
            var_global = np.var(patch)
            mean_global = np.mean(patch)
            if var_global < max_var and mean_global > min_mean:
                rgb_patch = rgb_img[i:i + sp, j:j + sp, :]
                collect_patchs.append(rgb_patch)

    return collect_patchs


if __name__ == '__main__':

    if opt.dataset == 'df2k':
        img_dir = PATHS[opt.dataset][opt.artifacts]['source']
        noise_dir = PATHS['datasets']['df2k'] + '/Corrupted_noise'
        sp = 256
        max_var = 20
        min_mean = 0
    else:
        img_dir = PATHS[opt.dataset][opt.artifacts]['hr']['train']
        noise_dir = PATHS['datasets']['dped'] + '/DPEDiphone_noise'
        sp = 256
        max_var = 20
        min_mean = 50

    assert not os.path.exists(noise_dir)
    os.mkdir(noise_dir)

    img_paths = sorted(glob.glob(osp.join(img_dir, '*.png')))
    cnt = 0
    for path in img_paths:
        img_name = osp.splitext(osp.basename(path))[0]
        print('**********', img_name, '**********')
        img = Image.open(path).convert('RGB')
        patchs = noise_patch(img, sp, max_var, min_mean)
        for idx, patch in enumerate(patchs):
            save_path = osp.join(noise_dir,
                                 '{}_{:03}.png'.format(img_name, idx))
            cnt += 1
            print('collect:', cnt, save_path)
            Image.fromarray(patch).save(save_path)
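
`noise_patch` keeps only flat, sufficiently bright patches: grayscale variance below `max_var` and mean above `min_mean`. A minimal sketch of the same test on a grayscale image stored as nested lists, so it runs without NumPy or PIL. `flat_patches` and the toy image are illustrative and not part of the commit:

```python
from statistics import mean, pvariance


def flat_patches(gray, sp, max_var, min_mean):
    """Return top-left corners of sp x sp patches that pass the noise test."""
    rows, cols = len(gray), len(gray[0])
    kept = []
    # Same stepping as noise_patch: non-overlapping tiles, stopping before
    # rows - sp, so a final tile that ends exactly at the border is skipped.
    for i in range(0, rows - sp, sp):
        for j in range(0, cols - sp, sp):
            pixels = [
                gray[r][c] for r in range(i, i + sp)
                for c in range(j, j + sp)
            ]
            # pvariance is the population variance, like np.var's default.
            if pvariance(pixels) < max_var and mean(pixels) > min_mean:
                kept.append((i, j))
    return kept


# 6x6 toy image: mostly flat at 100, one textured tile and one dark tile.
gray = [[100] * 6 for _ in range(6)]
gray[0][2], gray[1][3] = 0, 255          # tile (0, 2): high variance
for r, c in ((2, 0), (2, 1), (3, 0), (3, 1)):
    gray[r][c] = 10                      # tile (2, 0): flat but too dark

kept = flat_patches(gray, sp=2, max_var=20, min_mean=50)
```

With the DPED thresholds (`max_var = 20`, `min_mean = 50`) only the two flat, bright tiles at `(0, 0)` and `(2, 2)` survive.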
data/realsr_preprocess/create_bicubic_dataset.py (new file, mode 0 → 100644)

import argparse
import os
import yaml
import utils
from PIL import Image
from tqdm import tqdm

import paddle
import paddle.vision.transforms.functional as TF

paddle.set_device('cpu')

parser = argparse.ArgumentParser(
    description='Apply the trained model to create a dataset')
parser.add_argument('--checkpoint',
                    default=None,
                    type=str,
                    help='checkpoint model to use')
parser.add_argument('--artifacts',
                    default='',
                    type=str,
                    help='selecting different artifacts type')
parser.add_argument('--name',
                    default='',
                    type=str,
                    help='additional string added to folder path')
parser.add_argument('--dataset',
                    default='df2k',
                    type=str,
                    help='selecting different datasets')
parser.add_argument('--track',
                    default='train',
                    type=str,
                    help='selecting train or valid track')
parser.add_argument('--num_res_blocks',
                    default=8,
                    type=int,
                    help='number of ResNet blocks')
parser.add_argument('--cleanup_factor',
                    default=2,
                    type=int,
                    help='downscaling factor for image cleanup')
parser.add_argument('--upscale_factor',
                    default=4,
                    type=int,
                    choices=[4],
                    help='super resolution upscale factor')
opt = parser.parse_args()

# define input and target directories
with open('./paths.yml', 'r') as stream:
    PATHS = yaml.load(stream)

if opt.dataset == 'df2k':
    path_sdsr = PATHS['datasets']['df2k'] + '/generated/sdsr/'
    path_tdsr = PATHS['datasets']['df2k'] + '/generated/tdsr/'
    input_source_dir = PATHS['df2k']['tdsr']['source']
    input_target_dir = PATHS['df2k']['tdsr']['target']
    source_files = [
        os.path.join(input_source_dir, x) for x in os.listdir(input_source_dir)
        if utils.is_image_file(x)
    ]
    target_files = [
        os.path.join(input_target_dir, x) for x in os.listdir(input_target_dir)
        if utils.is_image_file(x)
    ]
else:
    path_sdsr = PATHS['datasets'][
        opt.dataset] + '/generated/' + opt.artifacts + '/' + opt.track + opt.name + '_sdsr/'
    path_tdsr = PATHS['datasets'][
        opt.dataset] + '/generated/' + opt.artifacts + '/' + opt.track + opt.name + '_tdsr/'
    input_source_dir = PATHS[opt.dataset][opt.artifacts]['hr'][opt.track]
    input_target_dir = None
    source_files = [
        os.path.join(input_source_dir, x) for x in os.listdir(input_source_dir)
        if utils.is_image_file(x)
    ]
    target_files = []

tdsr_hr_dir = path_tdsr + 'HR'
tdsr_lr_dir = path_tdsr + 'LR'

assert not os.path.exists(PATHS['datasets'][opt.dataset])

if not os.path.exists(tdsr_hr_dir):
    os.makedirs(tdsr_hr_dir)
if not os.path.exists(tdsr_lr_dir):
    os.makedirs(tdsr_lr_dir)

# generate the noisy images
with paddle.no_grad():
    for file in tqdm(source_files, desc='Generating images from source'):
        # load HR image
        input_img = Image.open(file)
        input_img = TF.to_tensor(input_img)

        # Resize HR image to clean it up and make sure it can be resized again
        resize2_img = utils.imresize(input_img, 1.0 / opt.cleanup_factor, True)
        _, w, h = resize2_img.shape
        w = w - w % opt.upscale_factor
        h = h - h % opt.upscale_factor
        resize2_cut_img = resize2_img[:, :w, :h]

        # Save resize2_cut_img as HR image for TDSR
        path = os.path.join(tdsr_hr_dir, os.path.basename(file))
        utils.to_pil_image(resize2_cut_img).save(path, 'PNG')

        # Generate resize3_cut_img and apply model
        resize3_cut_img = utils.imresize(resize2_cut_img,
                                         1.0 / opt.upscale_factor, True)

        # Save resize3_cut_noisy_img as LR image for TDSR
        path = os.path.join(tdsr_lr_dir, os.path.basename(file))
        utils.to_pil_image(resize3_cut_img).save(path, 'PNG')

    for file in tqdm(target_files, desc='Generating images from target'):
        # load HR image
        input_img = Image.open(file)
        input_img = TF.to_tensor(input_img)

        # Save input_img as HR image for TDSR
        path = os.path.join(tdsr_hr_dir, os.path.basename(file))
        utils.to_pil_image(input_img).save(path, 'PNG')

        # generate resized version of input_img
        resize_img = utils.imresize(input_img, 1.0 / opt.upscale_factor, True)

        # Save resize_noisy_img as LR image for TDSR
        path = os.path.join(tdsr_lr_dir, os.path.basename(file))
        utils.to_pil_image(resize_img).save(path, 'PNG')
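
For source images, the script first downscales by `cleanup_factor`, then trims each side to a multiple of `upscale_factor` so the second downscale divides evenly. A small sketch of that size bookkeeping, assuming the resize rounds output sizes up (as `fix_scale_and_size` in `imresize.py` does with `np.ceil`). `tdsr_sizes` is a hypothetical helper, not part of the commit:

```python
import math


def tdsr_sizes(height, width, cleanup_factor=2, upscale_factor=4):
    """Return ((hr_h, hr_w), (lr_h, lr_w)) for one source image."""
    # Step 1: clean-up downscale (output size rounded up, by assumption).
    clean_h = math.ceil(height / cleanup_factor)
    clean_w = math.ceil(width / cleanup_factor)

    # Step 2: trim so both sides are multiples of the upscale factor.
    hr_h = clean_h - clean_h % upscale_factor
    hr_w = clean_w - clean_w % upscale_factor

    # Step 3: the LR image is the HR image divided by the upscale factor.
    return (hr_h, hr_w), (hr_h // upscale_factor, hr_w // upscale_factor)


hr_size, lr_size = tdsr_sizes(1356, 2040)
```

For a 1356×2040 input the clean-up pass gives 678×1020, the trim drops two rows to reach 676×1020, and the LR image is 169×255.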
data/realsr_preprocess/create_kernel_dataset.py (new file, mode 0 → 100644)

import os
import yaml
import glob
import utils
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from imresize import imresize
from scipy.io import loadmat

import paddle
import paddle.vision.transforms.functional as TF

paddle.set_device('cpu')

parser = argparse.ArgumentParser(
    description='Apply the trained model to create a dataset')
parser.add_argument('--kernel_path',
                    default='./preprocess/KernelGAN/results',
                    type=str,
                    help='kernel path to use')
parser.add_argument('--artifacts',
                    default='',
                    type=str,
                    help='selecting different artifacts type')
parser.add_argument('--name',
                    default='',
                    type=str,
                    help='additional string added to folder path')
parser.add_argument('--dataset',
                    default='df2k',
                    type=str,
                    help='selecting different datasets')
parser.add_argument('--track',
                    default='train',
                    type=str,
                    help='selecting train or valid track')
parser.add_argument('--num_res_blocks',
                    default=8,
                    type=int,
                    help='number of ResNet blocks')
parser.add_argument('--cleanup_factor',
                    default=2,
                    type=int,
                    help='downscaling factor for image cleanup')
parser.add_argument('--upscale_factor',
                    default=4,
                    type=int,
                    choices=[4],
                    help='super resolution upscale factor')
opt = parser.parse_args()

# define input and target directories
with open('./paths.yml', 'r') as stream:
    PATHS = yaml.load(stream)

if opt.dataset == 'df2k':
    path_sdsr = PATHS['datasets']['df2k'] + '/generated/sdsr/'
    path_tdsr = PATHS['datasets']['df2k'] + '/generated/tdsr/'
    input_source_dir = PATHS['df2k']['tdsr']['source']
    input_target_dir = PATHS['df2k']['tdsr']['target']
    source_files = [
        os.path.join(input_source_dir, x) for x in os.listdir(input_source_dir)
        if utils.is_image_file(x)
    ]
    target_files = [
        os.path.join(input_target_dir, x) for x in os.listdir(input_target_dir)
        if utils.is_image_file(x)
    ]
else:
    path_sdsr = PATHS['datasets'][
        opt.dataset] + '/generated/' + opt.artifacts + '/' + opt.track + opt.name + '_sdsr/'
    path_tdsr = PATHS['datasets'][
        opt.dataset] + '/generated/' + opt.artifacts + '/' + opt.track + opt.name + '_tdsr/'
    input_source_dir = PATHS[opt.dataset][opt.artifacts]['hr'][opt.track]
    input_target_dir = None
    source_files = [
        os.path.join(input_source_dir, x) for x in os.listdir(input_source_dir)
        if utils.is_image_file(x)
    ]
    target_files = []

tdsr_hr_dir = path_tdsr + 'HR'
tdsr_lr_dir = path_tdsr + 'LR'

assert not os.path.exists(PATHS['datasets'][opt.dataset])

if not os.path.exists(tdsr_hr_dir):
    os.makedirs(tdsr_hr_dir)
if not os.path.exists(tdsr_lr_dir):
    os.makedirs(tdsr_lr_dir)

kernel_paths = glob.glob(os.path.join(opt.kernel_path, '*/*_kernel_x4.mat'))
kernel_num = len(kernel_paths)
print('kernel_num: ', kernel_num)

# generate the noisy images
with paddle.no_grad():
    for file in tqdm(source_files, desc='Generating images from source'):
        # load HR image
        input_img = Image.open(file)
        input_img = TF.to_tensor(input_img)

        # Resize HR image to clean it up and make sure it can be resized again
        resize2_img = utils.imresize(input_img, 1.0 / opt.cleanup_factor, True)
        _, w, h = resize2_img.shape
        w = w - w % opt.upscale_factor
        h = h - h % opt.upscale_factor
        resize2_cut_img = resize2_img[:, :w, :h]

        # Save resize2_cut_img as HR image for TDSR
        path = os.path.join(tdsr_hr_dir, os.path.basename(file))
        resize2_cut_img = utils.to_pil_image(resize2_cut_img)
        resize2_cut_img.save(path, 'PNG')

        # Generate resize3_cut_img and apply model
        kernel_path = kernel_paths[np.random.randint(0, kernel_num)]
        mat = loadmat(kernel_path)
        k = np.array([mat['Kernel']]).squeeze()
        resize3_cut_img = imresize(np.array(resize2_cut_img),
                                   scale_factor=1.0 / opt.upscale_factor,
                                   kernel=k)

        # Save resize3_cut_img as LR image for TDSR
        path = os.path.join(tdsr_lr_dir, os.path.basename(file))
        utils.to_pil_image(resize3_cut_img).save(path, 'PNG')

    for file in tqdm(target_files, desc='Generating images from target'):
        # load HR image
        input_img = Image.open(file)
        input_img = TF.to_tensor(input_img)

        # Save input_img as HR image for TDSR
        path = os.path.join(tdsr_hr_dir, os.path.basename(file))
        HR_img = utils.to_pil_image(input_img)
        HR_img.save(path, 'PNG')

        # generate resized version of input_img
        kernel_path = kernel_paths[np.random.randint(0, kernel_num)]
        mat = loadmat(kernel_path)
        k = np.array([mat['Kernel']]).squeeze()
        resize_img = imresize(np.array(HR_img),
                              scale_factor=1.0 / opt.upscale_factor,
                              kernel=k)

        # Save resize_noisy_img as LR image for TDSR
        path = os.path.join(tdsr_lr_dir, os.path.basename(file))
        utils.to_pil_image(resize_img).save(path, 'PNG')
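
When a numeric kernel is passed, `imresize` takes the `numeric_kernel` path in `imresize.py`: correlate the image with the kernel, then subsample. A 1-D sketch of that degradation in plain Python, under stated assumptions: an odd-length kernel, and reflect padding at the edges (which is what `scipy.ndimage` uses by default). `blur_then_subsample` and the toy signal are illustrative, not part of the commit:

```python
import math


def blur_then_subsample(signal, kernel, scale):
    """1-D analogue of numeric_kernel: correlate, then subsample."""
    n, k = len(signal), len(kernel)
    half = k // 2

    def at(i):
        # Assumed edge handling: reflect about the border, (b a | a b c | c b).
        if i < 0:
            i = -i - 1
        elif i >= n:
            i = 2 * n - 1 - i
        return signal[i]

    # Correlation (no kernel flip), centred on each sample.
    blurred = [
        sum(kernel[j] * at(i + j - half) for j in range(k)) for i in range(n)
    ]

    # Same index rule as numeric_kernel:
    # round(linspace(0, n - 1 / scale, out_len)).
    out_len = math.ceil(n * scale)
    stop = n - 1 / scale
    step = stop / (out_len - 1) if out_len > 1 else 0.0
    indices = [round(step * t) for t in range(out_len)]
    return [blurred[i] for i in indices]


# A step edge, blurred with a small triangular kernel and halved in length.
edge = [0, 0, 0, 0, 8, 8, 8, 8]
low_res = blur_then_subsample(edge, [0.25, 0.5, 0.25], scale=0.5)
```

The kernel dataset script does the same in 2-D per channel, with a kernel drawn at random from the KernelGAN results for each image.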
data/realsr_preprocess/imresize.py (new file, mode 0 → 100644)

# reference from kernelgan
import numpy as np
from scipy.ndimage import filters, measurements, interpolation
from math import pi


def imresize(im,
             scale_factor=None,
             output_shape=None,
             kernel=None,
             antialiasing=True,
             kernel_shift_flag=False):
    # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
    scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape,
                                                    scale_factor)

    # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only)
    if type(kernel) == np.ndarray and scale_factor[0] <= 1:
        return numeric_kernel(im, kernel, scale_factor, output_shape,
                              kernel_shift_flag)

    # Choose interpolation method, each method has the matching kernel size
    method, kernel_width = {
        "cubic": (cubic, 4.0),
        "lanczos2": (lanczos2, 4.0),
        "lanczos3": (lanczos3, 6.0),
        "box": (box, 1.0),
        "linear": (linear, 2.0),
        None: (cubic, 4.0)  # Default interpolation method is cubic
    }.get(kernel)

    # Antialiasing is only used when downscaling
    antialiasing *= (scale_factor[0] < 1)

    # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient
    sorted_dims = np.argsort(np.array(scale_factor)).tolist()

    # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
    out_im = np.copy(im)
    for dim in sorted_dims:
        # No point doing calculations for scale-factor 1. nothing will happen anyway
        if scale_factor[dim] == 1.0:
            continue

        # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
        # weights that multiply the values there to get its result.
        weights, field_of_view = contributions(im.shape[dim], output_shape[dim],
                                               scale_factor[dim], method,
                                               kernel_width, antialiasing)

        # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim
        out_im = resize_along_dim(out_im, dim, weights, field_of_view)

    return out_im


def fix_scale_and_size(input_shape, output_shape, scale_factor):
    # First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the
    # same size as the number of input dimensions)
    if scale_factor is not None:
        # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
        if np.isscalar(scale_factor):
            scale_factor = [scale_factor, scale_factor]

        # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales
        scale_factor = list(scale_factor)
        scale_factor.extend([1] * (len(input_shape) - len(scale_factor)))

    # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size
    # to all the unspecified dimensions
    if output_shape is not None:
        output_shape = list(np.uint(np.array(output_shape))) + list(
            input_shape[len(output_shape):])

    # Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is
    # sub-optimal, because there can be different scales to the same output-shape.
    if scale_factor is None:
        scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)

    # Dealing with missing output-shape. calculating according to scale-factor
    if output_shape is None:
        output_shape = np.uint(
            np.ceil(np.array(input_shape) * np.array(scale_factor)))

    return scale_factor, output_shape


def contributions(in_length, out_length, scale, kernel, kernel_width,
                  antialiasing):
    # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
    # such that each position from the field_of_view will be multiplied with a matching filter from the
    # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
    # around it. This is only done for one dimension of the image.

    # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of
    # 1/sf. this means filtering is more 'low-pass filter'.
    fixed_kernel = (
        lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
    kernel_width *= 1.0 / scale if antialiasing else 1.0

    # These are the coordinates of the output image
    out_coordinates = np.arange(1, out_length + 1)

    # These are the matching positions of the output-coordinates on the input image coordinates.
    # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
    # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
    # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to
    # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big
    # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).
    # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is
    # at d=1, and so on (d = p - 0.5).  we calculate (d_new = d_old / sf) which means:
    # (p_new-0.5 = (p_old-0.5) / sf)     ->     p_new = p_old/sf + 0.5 * (1-1/sf)
    match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale)

    # This is the left boundary to start multiplying the filter from, it depends on the size of the filter
    left_boundary = np.floor(match_coordinates - kernel_width / 2)

    # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers
    # of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them)
    expanded_kernel_width = np.ceil(kernel_width) + 2

    # Determine a set of field_of_view for each each output position, these are the pixels in the input image
    # that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the
    # vertical dim is the pixels it 'sees' (kernel_size + 2)
    field_of_view = np.squeeze(
        np.uint(
            np.expand_dims(left_boundary, axis=1) +
            np.arange(expanded_kernel_width) - 1))

    # Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the
    # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in
    # 'field_of_view')
    weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) -
                           field_of_view - 1)

    # Normalize weights to sum up to 1. be careful from dividing by 0
    sum_weights = np.sum(weights, axis=1)
    sum_weights[sum_weights == 0] = 1.0
    weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)

    # We use this mirror structure as a trick for reflection padding at the boundaries
    mirror = np.uint(
        np.concatenate(
            (np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
    field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]

    # Get rid of  weights and pixel positions that are of zero weight
    non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
    weights = np.squeeze(weights[:, non_zero_out_pixels])
    field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])

    # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size
    return weights, field_of_view


def resize_along_dim(im, dim, weights, field_of_view):
    # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
    tmp_im = np.swapaxes(im, dim, 0)

    # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
    # tmp_im[field_of_view.T], (bsxfun style)
    weights = np.reshape(weights.T,
                         list(weights.T.shape) + (np.ndim(im) - 1) * [1])

    # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1.
    # for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim
    # only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with
    # the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style:
    # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the
    # same number
    tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0)

    # Finally we swap back the axes to the original order
    return np.swapaxes(tmp_out_im, dim, 0)


def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag):
    # See kernel_shift function to understand what this is
    if kernel_shift_flag:
        kernel = kernel_shift(kernel, scale_factor)

    # First run a correlation (convolution with flipped kernel)
    out_im = np.zeros_like(im)
    for channel in range(np.ndim(im)):
        out_im[:, :, channel] = filters.correlate(im[:, :, channel], kernel)

    # Then subsample and return
    return out_im[np.round(
        np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])
    ).astype(int)[:, None],
                  np.round(
                      np.linspace(0, im.shape[1] - 1 / scale_factor[1],
                                  output_shape[1])).astype(int), :]


def kernel_shift(kernel, sf):
    # There are two reasons for shifting the kernel:
    # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no possible way to know
    #    the degradation process included shifting so we always assume center of mass is center of the kernel.
    # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first
    #    pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the
    #    top left corner of the first pixel. that is why different shift size needed between od and even size.
    # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows:
    # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth.

    # First calculate the current center of mass for the kernel
    current_center_of_mass = measurements.center_of_mass(kernel)

    # The second ("+ 0.5 * ....") is for applying condition 2 from the comments above
    wanted_center_of_mass = np.array(
        kernel.shape) // 2 + 0.5 * (sf - (kernel.shape[0] % 2))
    # wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (np.array(sf)[0:2] - (kernel.shape[0] % 2))
    # Define the shift vector for the kernel shifting (x,y)
    shift_vec = wanted_center_of_mass - current_center_of_mass

    # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
    # (biggest shift among dims + 1 for safety)
    kernel = np.pad(kernel,
                    np.int(np.ceil(np.max(shift_vec))) + 1, 'constant')

    # Finally shift the kernel and return
    return interpolation.shift(kernel, shift_vec)


# These next functions are all interpolation methods. x is the distance from the left pixel center


def cubic(x):
    absx = np.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return ((1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) +
            (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((1 < absx) &
                                                           (absx <= 2)))


def lanczos2(x):
    return (((np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps) /
             ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)) * (abs(x) < 2))


def box(x):
    return ((-0.5 <= x) & (x < 0.5)) * 1.0


def lanczos3(x):
    return (((np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps) /
             ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)) * (abs(x) < 3))


def linear(x):
    return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
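
The comments in `contributions` derive the mapping p_in = p_out/sf + 0.5·(1 − 1/sf) between 1-based output and input pixel centres, then weight the neighbouring input pixels with the chosen kernel and normalise. A short scalar sketch of those two steps for a single output pixel, without antialiasing. The scalar `cubic` mirrors the array version above, while `match_coordinate` and `weights_for` are hypothetical helpers for illustration:

```python
import math


def cubic(x):
    """Scalar form of the cubic interpolation kernel used in imresize.py."""
    ax = abs(x)
    if ax <= 1:
        return 1.5 * ax**3 - 2.5 * ax**2 + 1
    if ax <= 2:
        return -0.5 * ax**3 + 2.5 * ax**2 - 4 * ax + 2
    return 0.0


def match_coordinate(p_out, scale):
    """Position of output pixel centre p_out in input coordinates (1-based)."""
    return p_out / scale + 0.5 * (1 - 1 / scale)


def weights_for(p_out, scale, kernel_width=4.0):
    """Return (centre, [(input_pixel, weight), ...]) for one output pixel."""
    centre = match_coordinate(p_out, scale)
    left = math.floor(centre - kernel_width / 2)
    # One extra pixel on each side, as in expanded_kernel_width.
    taps = [left + t for t in range(math.ceil(kernel_width) + 2)]
    raw = [cubic(centre - p) for p in taps]
    total = sum(raw) or 1.0  # guard against dividing by zero
    return centre, [(p, w / total) for p, w in zip(taps, raw) if w != 0]


# The worked example from the comments: [1, 2, 3, 4] -> [1, 2] at sf = 0.5.
down = [match_coordinate(p, 0.5) for p in (1, 2)]
# Weights for the first output pixel when upscaling by 2.
centre, taps = weights_for(1, 2.0)
```

For the downscaling example the two output pixels land at 1.5 and 3.5, that is, on the boundaries between input pixels 1|2 and 3|4, as the comment describes. Input positions below 1 in `taps` are the ones the real code folds back into the image with its mirror trick.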
data/realsr_preprocess/paths.yml (new file, mode 0 → 100644)

df2k:
  tdsr:
    source: '/workspace/datasets/ntire20/Corrupted-tr-x'
    target: '/workspace/datasets/ntire20/Corrupted-tr-y'
    valid:
dped:
  clean:
    hr:
      train: '/workspace/datasets/ntire20/DPEDiphone-tr-x'
      valid: '/workspace/datasets/ntire20/DPEDiphone-va'
datasets:
  df2k: 'DF2K'
  dped: 'DPED'
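
The preprocessing scripts index this file as nested dictionaries. A sketch of the lookups for the `dped` case, using a plain dict that mirrors the YAML above instead of parsing the file. `tdsr_dirs` is a hypothetical helper that repeats the string concatenation from the dataset scripts:

```python
# Hand-written mirror of paths.yml (the empty df2k `valid` entry is omitted).
PATHS = {
    'df2k': {
        'tdsr': {
            'source': '/workspace/datasets/ntire20/Corrupted-tr-x',
            'target': '/workspace/datasets/ntire20/Corrupted-tr-y',
        },
    },
    'dped': {
        'clean': {
            'hr': {
                'train': '/workspace/datasets/ntire20/DPEDiphone-tr-x',
                'valid': '/workspace/datasets/ntire20/DPEDiphone-va',
            },
        },
    },
    'datasets': {'df2k': 'DF2K', 'dped': 'DPED'},
}


def tdsr_dirs(paths, dataset, artifacts, track, name=''):
    """Build the HR/LR output directories the way the non-df2k branch does."""
    base = (paths['datasets'][dataset] + '/generated/' + artifacts + '/' +
            track + name + '_tdsr/')
    return base + 'HR', base + 'LR'


source_dir = PATHS['dped']['clean']['hr']['train']
hr_dir, lr_dir = tdsr_dirs(PATHS, 'dped', 'clean', 'train')
```

With `--dataset dped --artifacts clean --track train` the output directories come out as `DPED/generated/clean/train_tdsr/HR` and `.../LR`, which line up with the `gt_folder` and `lq_folder` entries in `configs/realsr_kernel_noise_x4_dped.yaml`.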
data/realsr_preprocess/utils.py
0 → 100644
浏览文件 @
5972bf7f
import
math
import
numpy
as
np
from
PIL
import
Image
import
paddle
# set random seed for reproducibility
np
.
random
.
seed
(
0
)
def
is_image_file
(
filename
):
return
any
(
filename
.
endswith
(
extension
)
for
extension
in
[
'.png'
,
'.jpg'
,
'.jpeg'
,
'.PNG'
,
'.JPG'
,
'.JPEG'
])
def
calculate_valid_crop_size
(
crop_size
,
upscale_factor
):
return
crop_size
-
(
crop_size
%
upscale_factor
)
def
gaussian_noise
(
image
,
std_dev
):
noise
=
np
.
rint
(
np
.
random
.
normal
(
loc
=
0.0
,
scale
=
std_dev
,
size
=
np
.
shape
(
image
)))
return
Image
.
fromarray
(
np
.
clip
(
image
+
noise
,
0
,
255
).
astype
(
np
.
uint8
))
#################################################################################
# MATLAB imresize taken from ESRGAN (https://github.com/xinntao/BasicSR)
#################################################################################
def cubic(x):
    absx = paddle.abs(x)
    absx2 = absx**2
    absx3 = absx**3

    temp1 = paddle.cast((absx <= 1), absx.dtype)
    temp2 = paddle.cast((absx > 1), absx.dtype) * paddle.cast(
        (absx <= 2), absx.dtype)
    return (1.5 * absx3 - 2.5 * absx2 +
            1) * temp1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * temp2

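The `cubic` function above is the piecewise bicubic kernel with support `[-2, 2]`. The sketch below re-implements it in NumPy (an assumption made so the example runs without Paddle; `cubic_np` is a hypothetical name) and checks two properties that the resize code relies on.

```python
import numpy as np


def cubic_np(x):
    # NumPy counterpart of the Paddle cubic kernel shown above
    absx = np.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    inner = (absx <= 1).astype(absx.dtype)                   # |x| <= 1
    outer = ((absx > 1) & (absx <= 2)).astype(absx.dtype)    # 1 < |x| <= 2
    return ((1.5 * absx3 - 2.5 * absx2 + 1) * inner +
            (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * outer)


# Value at the centre, at the integer knots, and outside the support
knots = cubic_np(np.array([0.0, 1.0, 2.0, 2.5]))

# Four taps around a sample that sits halfway between two pixels
taps = cubic_np(np.array([-1.5, -0.5, 0.5, 1.5]))
```

The kernel is 1 at the centre and 0 at the other integer offsets, and the four taps sum to 1, so a constant image stays constant after interpolation.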

def calculate_weights_indices(in_length, out_length, scale, kernel,
                              kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = paddle.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = paddle.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.reshape([out_length, 1]).expand([
        out_length, P
    ]) + paddle.linspace(0, P - 1, P).reshape([1, P]).expand([out_length, P])

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.reshape([out_length, 1]).expand([out_length, P
                                                            ]) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = paddle.sum(weights, 1).reshape([out_length, 1])
    weights = weights / weights_sum.expand([out_length, P])

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = np.sum((weights.numpy() == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices[:, 1:1 + P - 2]
        weights = weights[:, 1:1 + P - 2]
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices[:, 0:P - 2]
        weights = weights[:, 0:P - 2]
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)

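The coordinate mapping in `calculate_weights_indices` is easier to follow with numbers. The sketch below reproduces only its first few steps in NumPy (an assumption, so that it runs without Paddle; `support_window` is a hypothetical helper name) for a 4x downscale of an 8-pixel axis to 2 pixels.

```python
import math

import numpy as np


def support_window(out_length, scale, kernel_width=4, antialiasing=True):
    # Widen the kernel when downscaling so that it also acts as a low-pass filter
    if scale < 1 and antialiasing:
        kernel_width = kernel_width / scale
    # 1-based output coordinates, as in the MATLAB-style original
    x = np.linspace(1, out_length, out_length)
    # Centre of each output pixel expressed in input-space coordinates
    u = x / scale + 0.5 * (1 - 1 / scale)
    # Left-most input pixel that can contribute to each output pixel
    left = np.floor(u - kernel_width / 2)
    # Upper bound on the number of contributing input pixels
    P = math.ceil(kernel_width) + 2
    return u, left, P


u, left, P = support_window(out_length=2, scale=0.25)
```

With `scale = 0.25` the kernel width grows from 4 to 16, the two output pixels are centred at input positions 2.5 and 6.5, and each may draw on up to 18 input pixels. The negative `left` values are the reason the caller pads the image symmetrically before filtering.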

def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: CHW RGB [0,1]
    # output: CHW RGB [0,1] w/o round

    in_C, in_H, in_W = img.shape
    _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = paddle.zeros([in_C, in_H + sym_len_Hs + sym_len_He, in_W])
    img_aug[:, sym_len_Hs:sym_len_Hs + in_H, :] = img

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = paddle.arange(sym_patch.shape[1] - 1, -1, -1)
    sym_patch_inv = paddle.index_select(sym_patch, inv_idx, 1)
    img_aug[:, :sym_len_Hs, :] = sym_patch_inv

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = paddle.arange(sym_patch.shape[1] - 1, -1, -1)
    sym_patch_inv = paddle.index_select(sym_patch, inv_idx, 1)
    img_aug[:,
            sym_len_Hs + in_H:sym_len_Hs + in_H + sym_len_He, :] = sym_patch_inv

    out_1 = paddle.zeros([in_C, out_H, in_W])
    kernel_width = weights_H.shape[1]
    for i in range(out_H):
        idx = int(indices_H[i][0])
        out_1[0, i, :] = paddle.mv(
            img_aug[0, idx:idx + kernel_width, :].transpose([1, 0]),
            (weights_H[i]))
        out_1[1, i, :] = paddle.mv(
            img_aug[1, idx:idx + kernel_width, :].transpose([1, 0]),
            (weights_H[i]))
        out_1[2, i, :] = paddle.mv(
            img_aug[2, idx:idx + kernel_width, :].transpose([1, 0]),
            (weights_H[i]))

    # process W dimension
    # symmetric copying
    out_1_aug = paddle.zeros([in_C, out_H, in_W + sym_len_Ws + sym_len_We])
    out_1_aug[:, :, sym_len_Ws:sym_len_Ws + in_W] = out_1

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = paddle.arange(sym_patch.shape[2] - 1, -1, -1)
    sym_patch_inv = paddle.index_select(sym_patch, inv_idx, 2)
    out_1_aug[:, :, 0:sym_len_Ws] = sym_patch_inv

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = paddle.arange(sym_patch.shape[2] - 1, -1, -1)
    sym_patch_inv = paddle.index_select(sym_patch, inv_idx, 2)
    out_1_aug[:, :,
              sym_len_Ws + in_W:sym_len_Ws + in_W + sym_len_We] = sym_patch_inv

    out_2 = paddle.zeros([in_C, out_H, out_W])
    kernel_width = weights_W.shape[1]
    for i in range(out_W):
        idx = int(indices_W[i][0])
        out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(
            weights_W[i])
        out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(
            weights_W[i])
        out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(
            weights_W[i])

    return paddle.clip(out_2, 0, 1)

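The "symmetric copying" step in `imresize` mirrors the border rows (and later the border columns) so that the widened kernel always has pixels to read. The NumPy sketch below illustrates that step along the height axis only. `symmetric_pad_rows` is a hypothetical helper written for this example, not part of the commit, and it slices the tail with `h - sym_len_e` so that a zero-length pad is also handled.

```python
import numpy as np


def symmetric_pad_rows(img, sym_len_s, sym_len_e):
    # img is CHW; mirror sym_len_s rows at the top and sym_len_e rows at the bottom
    c, h, w = img.shape
    out = np.zeros((c, h + sym_len_s + sym_len_e, w), dtype=img.dtype)
    out[:, sym_len_s:sym_len_s + h, :] = img
    # First rows, reversed, go in front of the image
    out[:, :sym_len_s, :] = img[:, :sym_len_s, :][:, ::-1, :]
    # Last rows, reversed, go after the image
    out[:, sym_len_s + h:, :] = img[:, h - sym_len_e:, :][:, ::-1, :]
    return out


# One channel, four rows holding 1, 2, 3, 4, one column
img = np.arange(4, dtype=np.float32).reshape(1, 4, 1) + 1
padded = symmetric_pad_rows(img, sym_len_s=2, sym_len_e=1)

# The same result from NumPy's built-in symmetric padding
reference = np.pad(img, ((0, 0), (2, 1), (0, 0)), mode='symmetric')
```

For this input the padded column reads 2, 1, 1, 2, 3, 4, 4: the edge pixel itself is repeated, which corresponds to NumPy's `symmetric` mode rather than `reflect`.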

def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

    Args:
        pic (paddle.Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not (isinstance(pic, paddle.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(
            type(pic)))

    elif isinstance(pic, paddle.Tensor):
        if len(pic.shape) not in {2, 3}:
            raise ValueError(
                'pic should be 2/3 dimensional. Got {} dimensions.'.format(
                    pic.ndimension()))

        elif len(pic.shape) == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError(
                'pic should be 2/3 dimensional. Got {} dimensions.'.format(
                    pic.ndim))

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

    npimg = pic
    if isinstance(pic, paddle.Tensor) and mode != 'F':
        pic = pic.numpy()
        if pic.dtype == 'float32':
            # scale [0, 1] floats to 8-bit and convert CHW to HWC
            npimg = np.transpose((pic * 255.).astype('uint8'), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a paddle.Tensor or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            raise ValueError(
                "Incorrect mode ({}) supplied for input type {}. Should be {}"
                .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ['LA']
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError("Only modes {} are supported for 2D inputs".format(
                permitted_2_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'LA'

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(
                permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(
                permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)
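The tensor branch of `to_pil_image` turns a CHW float image in `[0, 1]` into the HWC `uint8` layout that PIL expects. The sketch below isolates that conversion in NumPy. `chw_float_to_hwc_uint8` is a hypothetical helper name, and the input array is made up for the example.

```python
import numpy as np


def chw_float_to_hwc_uint8(pic):
    # Scale to [0, 255], truncate to uint8, then move channels last
    return np.transpose((pic * 255.).astype('uint8'), (1, 2, 0))


# Hypothetical 3-channel image, 2 rows by 4 columns, one constant per channel
pic = np.zeros((3, 2, 4), dtype=np.float32)
pic[0] = 1.0   # red channel fully on
pic[1] = 0.5   # green channel at half intensity

npimg = chw_float_to_hwc_uint8(pic)
```

Note that `astype('uint8')` truncates instead of rounding, so 0.5 maps to 127 and not 128. With three channels and a `uint8` dtype, the function above would then select mode `'RGB'`.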
ppgan/datasets/preprocess/__init__.py
View file @ 5972bf7f
 from .io import LoadImageFromFile
 from .transforms import (PairedRandomCrop, PairedRandomHorizontalFlip,
                          PairedRandomVerticalFlip, PairedRandomTransposeHW,
-                         SRPairedRandomCrop, SplitPairedImage)
+                         SRPairedRandomCrop, SplitPairedImage, SRNoise)
 from .builder import build_preprocess
ppgan/datasets/preprocess/transforms.py
View file @ 5972bf7f
@@ -14,9 +14,13 @@
import sys
import cv2
import glob
import random
import numbers
import collections
import numpy as np

from PIL import Image

import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F
@@ -230,3 +234,31 @@ class SRPairedRandomCrop(T.BaseTransform):
        outputs = (lq, gt)
        return outputs


@TRANSFORMS.register()
class SRNoise(T.BaseTransform):
    """Super resolution noise.

    Args:
        noise_path (str): directory of noise image.
        size (int): cropped noise patch size.
    """
    def __init__(self, noise_path, size, keys=None):
        self.noise_path = noise_path
        self.noise_imgs = sorted(glob.glob(noise_path + '*.png'))
        self.size = size
        self.keys = keys
        self.transform = T.Compose([
            T.RandomCrop(size),
            T.Transpose(),
            T.Normalize([0., 0., 0.], [255., 255., 255.])
        ])

    def _apply_image(self, image):
        idx = np.random.randint(0, len(self.noise_imgs))
        noise = self.transform(Image.open(self.noise_imgs[idx]))
        normed_noise = noise - np.mean(noise, axis=(1, 2), keepdims=True)
        image = image + normed_noise
        image = np.clip(image, 0., 1.)
        return image
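The core of `SRNoise._apply_image` is a zero-mean noise injection: the per-channel mean of the noise patch is subtracted, so adding the patch perturbs texture without shifting overall brightness. The sketch below shows that arithmetic in NumPy with synthetic data. `inject_noise` is a hypothetical name, and the random arrays stand in for a real image and a cropped noise patch loaded from `noise_path`.

```python
import numpy as np


def inject_noise(image, noise):
    # Remove the per-channel mean so the noise has no brightness offset
    normed_noise = noise - np.mean(noise, axis=(1, 2), keepdims=True)
    # Add the noise and keep the result in the valid [0, 1] range
    return np.clip(image + normed_noise, 0., 1.), normed_noise


rng = np.random.default_rng(0)
# Synthetic CHW image and noise patch, both assumed to be scaled to [0, 1]
image = rng.random((3, 8, 8)).astype(np.float32)
noise = (0.5 + 0.05 * rng.standard_normal((3, 8, 8))).astype(np.float32)

noisy, normed_noise = inject_noise(image, noise)
```

The image and the noise patch must have the same CHW shape, which is why the transform crops the noise to `size` before adding it.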