Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
2aaef2e9
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2aaef2e9
编写于
9月 14, 2020
作者:
L
LielinJiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add srmodel
上级
54896a27
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
72 addition
and
5 deletion
+72
-5
ppgan/datasets/__init__.py
ppgan/datasets/__init__.py
+1
-0
ppgan/datasets/base_dataset.py
ppgan/datasets/base_dataset.py
+4
-1
ppgan/datasets/builder.py
ppgan/datasets/builder.py
+4
-0
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+56
-2
ppgan/models/__init__.py
ppgan/models/__init__.py
+2
-0
ppgan/models/generators/__init__.py
ppgan/models/generators/__init__.py
+2
-1
ppgan/utils/visual.py
ppgan/utils/visual.py
+3
-1
未找到文件。
ppgan/datasets/__init__.py
浏览文件 @
2aaef2e9
from
.unpaired_dataset
import
UnpairedDataset
from
.single_dataset
import
SingleDataset
from
.paired_dataset
import
PairedDataset
from
.sr_image_dataset
import
SRImageDataset
\ No newline at end of file
ppgan/datasets/base_dataset.py
浏览文件 @
2aaef2e9
...
...
@@ -95,6 +95,9 @@ def get_transform(cfg,
if
convert
:
transform_list
+=
[
transforms
.
Permute
(
to_rgb
=
True
)]
transform_list
+=
[
transforms
.
Normalize
((
127.5
,
127.5
,
127.5
),
(
127.5
,
127.5
,
127.5
))
transforms
.
Normalize
((
0.
,
0.
,
0.
),
(
255.
,
255.
,
255.
))
]
# transform_list += [
# transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
# ]
return
transforms
.
Compose
(
transform_list
)
ppgan/datasets/builder.py
浏览文件 @
2aaef2e9
...
...
@@ -111,4 +111,8 @@ def build_dataloader(cfg, is_train=True):
dataloader
=
DictDataLoader
(
dataset
,
batch_size
,
is_train
,
num_workers
)
# for i, item in enumerate(dataloader):
# print(i, item.keys())
# # break
# print('dataset build success!')
return
dataloader
ppgan/engine/trainer.py
浏览文件 @
2aaef2e9
import
os
import
time
import
copy
import
logging
import
paddle
...
...
@@ -10,7 +11,7 @@ from ..datasets.builder import build_dataloader
from
..models.builder
import
build_model
from
..utils.visual
import
tensor2img
,
save_image
from
..utils.filesystem
import
save
,
load
,
makedirs
from
..metric.psnr_ssim
import
calculate_psnr
,
calculate_ssim
class
Trainer
:
def
__init__
(
self
,
cfg
):
...
...
@@ -45,9 +46,11 @@ class Trainer:
# time count
self
.
time_count
=
{}
self
.
best_metric
=
{}
def
distributed_data_parallel
(
self
):
strategy
=
paddle
.
prepare_context
()
strategy
=
paddle
.
distributed
.
prepare_context
()
for
name
in
self
.
model
.
model_names
:
if
isinstance
(
name
,
str
):
net
=
getattr
(
self
.
model
,
'net'
+
name
)
...
...
@@ -78,11 +81,61 @@ class Trainer:
step_start_time
=
time
.
time
()
self
.
logger
.
info
(
'train one epoch time: {}'
.
format
(
time
.
time
()
-
start_time
))
self
.
validate
()
self
.
model
.
lr_scheduler
.
step
()
if
epoch
%
self
.
weight_interval
==
0
:
self
.
save
(
epoch
,
'weight'
,
keep
=-
1
)
self
.
save
(
epoch
)
def
validate
(
self
):
if
not
hasattr
(
self
,
'val_dataloader'
):
self
.
val_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
val
,
is_train
=
False
)
metric_result
=
{}
for
i
,
data
in
enumerate
(
self
.
val_dataloader
):
self
.
batch_id
=
i
self
.
model
.
set_input
(
data
)
self
.
model
.
test
()
visual_results
=
{}
current_paths
=
self
.
model
.
get_image_paths
()
current_visuals
=
self
.
model
.
get_current_visuals
()
# print('debug1:', self.cfg.validate.metrics)
for
j
in
range
(
len
(
current_paths
)):
short_path
=
os
.
path
.
basename
(
current_paths
[
j
])
basename
=
os
.
path
.
splitext
(
short_path
)[
0
]
for
k
,
img_tensor
in
current_visuals
.
items
():
name
=
'%s_%s'
%
(
basename
,
k
)
visual_results
.
update
({
name
:
img_tensor
[
j
]})
# print('debug2:', self.cfg.validate.metrics)
if
'psnr'
in
self
.
cfg
.
validate
.
metrics
:
# args = copy.deepcopy(self.cfg.validate.metrics.pnsr)
# args.pop('name')
if
'psnr'
not
in
metric_result
:
metric_result
[
'psnr'
]
=
calculate_psnr
(
tensor2img
(
current_visuals
[
'output'
][
j
]),
tensor2img
(
current_visuals
[
'gt'
][
j
]),
**
self
.
cfg
.
validate
.
metrics
.
psnr
)
else
:
metric_result
[
'psnr'
]
+=
calculate_psnr
(
tensor2img
(
current_visuals
[
'output'
][
j
]),
tensor2img
(
current_visuals
[
'gt'
][
j
]),
**
self
.
cfg
.
validate
.
metrics
.
psnr
)
if
'ssim'
in
self
.
cfg
.
validate
.
metrics
:
if
'ssim'
not
in
metric_result
:
metric_result
[
'ssim'
]
=
calculate_ssim
(
tensor2img
(
current_visuals
[
'output'
][
j
]),
tensor2img
(
current_visuals
[
'gt'
][
j
]),
**
self
.
cfg
.
validate
.
metrics
.
ssim
)
else
:
metric_result
[
'ssim'
]
+=
calculate_ssim
(
tensor2img
(
current_visuals
[
'output'
][
j
]),
tensor2img
(
current_visuals
[
'gt'
][
j
]),
**
self
.
cfg
.
validate
.
metrics
.
ssim
)
self
.
visual
(
'visual_val'
,
visual_results
=
visual_results
)
if
i
%
self
.
log_interval
==
0
:
self
.
logger
.
info
(
'val iter: [%d/%d]'
%
(
i
,
len
(
self
.
val_dataloader
)))
for
metric_name
in
metric_result
.
keys
():
metric_result
[
metric_name
]
/=
len
(
self
.
val_dataloader
.
dataset
)
self
.
logger
.
info
(
'Epoch {} validate end: {}'
.
format
(
self
.
current_epoch
,
metric_result
))
def
test
(
self
):
if
not
hasattr
(
self
,
'test_dataloader'
):
self
.
test_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
test
,
...
...
@@ -210,5 +263,6 @@ class Trainer:
for
name
in
self
.
model
.
model_names
:
if
isinstance
(
name
,
str
):
self
.
logger
.
info
(
'laod model {} {} params!'
.
format
(
self
.
cfg
.
model
.
name
,
'net'
+
name
))
net
=
getattr
(
self
.
model
,
'net'
+
name
)
net
.
set_dict
(
state_dicts
[
'net'
+
name
])
ppgan/models/__init__.py
浏览文件 @
2aaef2e9
from
.base_model
import
BaseModel
from
.cycle_gan_model
import
CycleGANModel
from
.pix2pix_model
import
Pix2PixModel
from
.srgan_model
import
SRGANModel
from
.sr_model
import
SRModel
ppgan/models/generators/__init__.py
浏览文件 @
2aaef2e9
from
.resnet
import
ResnetGenerator
from
.unet
import
UnetGenerator
\ No newline at end of file
from
.unet
import
UnetGenerator
from
.rrdb_net
import
RRDBNet
\ No newline at end of file
ppgan/utils/visual.py
浏览文件 @
2aaef2e9
...
...
@@ -15,7 +15,9 @@ def tensor2img(input_image, imtype=np.uint8):
image_numpy
=
image_numpy
[
0
]
if
image_numpy
.
shape
[
0
]
==
1
:
# grayscale to RGB
image_numpy
=
np
.
tile
(
image_numpy
,
(
3
,
1
,
1
))
image_numpy
=
(
np
.
transpose
(
image_numpy
,
(
1
,
2
,
0
))
+
1
)
/
2.0
*
255.0
# post-processing: tranpose and scaling
# image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
image_numpy
=
image_numpy
.
clip
(
0
,
1
)
image_numpy
=
(
np
.
transpose
(
image_numpy
,
(
1
,
2
,
0
)))
*
255.0
# post-processing: tranpose and scaling
else
:
# if it is a numpy array, do nothing
image_numpy
=
input_image
return
image_numpy
.
astype
(
imtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录