Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
96551765
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
96551765
编写于
10月 21, 2022
作者:
B
Birdylx
提交者:
GitHub
10月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support amp for esrgan (#712)
上级
af681065
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
124 addition
and
15 deletion
+124
-15
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+11
-6
ppgan/models/edvr_model.py
ppgan/models/edvr_model.py
+3
-3
ppgan/models/esrgan_model.py
ppgan/models/esrgan_model.py
+85
-0
ppgan/models/msvsr_model.py
ppgan/models/msvsr_model.py
+3
-3
ppgan/models/sr_model.py
ppgan/models/sr_model.py
+18
-0
test_tipc/configs/edvr/train_infer_python.txt
test_tipc/configs/edvr/train_infer_python.txt
+1
-1
test_tipc/configs/esrgan/train_infer_python.txt
test_tipc/configs/esrgan/train_infer_python.txt
+1
-1
test_tipc/configs/msvsr/train_infer_python.txt
test_tipc/configs/msvsr/train_infer_python.txt
+1
-1
test_tipc/prepare.sh
test_tipc/prepare.sh
+1
-0
未找到文件。
ppgan/engine/trainer.py
浏览文件 @
96551765
...
@@ -133,7 +133,7 @@ class Trainer:
...
@@ -133,7 +133,7 @@ class Trainer:
cfg
.
optimizer
)
cfg
.
optimizer
)
# setup amp train
# setup amp train
self
.
scaler
=
self
.
setup_amp_train
()
if
self
.
cfg
.
amp
else
None
self
.
scaler
s
=
self
.
setup_amp_train
()
if
self
.
cfg
.
amp
else
None
# multiple gpus prepare
# multiple gpus prepare
if
ParallelEnv
().
nranks
>
1
:
if
ParallelEnv
().
nranks
>
1
:
...
@@ -164,11 +164,10 @@ class Trainer:
...
@@ -164,11 +164,10 @@ class Trainer:
self
.
profiler_options
=
cfg
.
profiler_options
self
.
profiler_options
=
cfg
.
profiler_options
def
setup_amp_train
(
self
):
def
setup_amp_train
(
self
):
""" decerate model, optimizer and return a GradScaler """
""" decerate model, optimizer and return a list of GradScaler """
self
.
logger
.
info
(
'use AMP to train. AMP level = {}'
.
format
(
self
.
logger
.
info
(
'use AMP to train. AMP level = {}'
.
format
(
self
.
cfg
.
amp_level
))
self
.
cfg
.
amp_level
))
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
1024
)
# need to decorate model and optim if amp_level == 'O2'
# need to decorate model and optim if amp_level == 'O2'
if
self
.
cfg
.
amp_level
==
'O2'
:
if
self
.
cfg
.
amp_level
==
'O2'
:
nets
,
optimizers
=
list
(
self
.
model
.
nets
.
values
()),
list
(
nets
,
optimizers
=
list
(
self
.
model
.
nets
.
values
()),
list
(
...
@@ -181,7 +180,13 @@ class Trainer:
...
@@ -181,7 +180,13 @@ class Trainer:
self
.
model
.
nets
[
k
]
=
nets
[
i
]
self
.
model
.
nets
[
k
]
=
nets
[
i
]
for
i
,
(
k
,
_
)
in
enumerate
(
self
.
optimizers
.
items
()):
for
i
,
(
k
,
_
)
in
enumerate
(
self
.
optimizers
.
items
()):
self
.
optimizers
[
k
]
=
optimizers
[
i
]
self
.
optimizers
[
k
]
=
optimizers
[
i
]
return
scaler
scalers
=
[
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
1024
)
for
i
in
range
(
len
(
self
.
optimizers
))
]
return
scalers
def
distributed_data_parallel
(
self
):
def
distributed_data_parallel
(
self
):
paddle
.
distributed
.
init_parallel_env
()
paddle
.
distributed
.
init_parallel_env
()
...
@@ -223,7 +228,7 @@ class Trainer:
...
@@ -223,7 +228,7 @@ class Trainer:
self
.
model
.
setup_input
(
data
)
self
.
model
.
setup_input
(
data
)
if
self
.
cfg
.
amp
:
if
self
.
cfg
.
amp
:
self
.
model
.
train_iter_amp
(
self
.
optimizers
,
self
.
scaler
,
self
.
model
.
train_iter_amp
(
self
.
optimizers
,
self
.
scaler
s
,
self
.
cfg
.
amp_level
)
# amp train
self
.
cfg
.
amp_level
)
# amp train
else
:
else
:
self
.
model
.
train_iter
(
self
.
optimizers
)
# norm train
self
.
model
.
train_iter
(
self
.
optimizers
)
# norm train
...
...
ppgan/models/edvr_model.py
浏览文件 @
96551765
...
@@ -76,7 +76,7 @@ class EDVRModel(BaseSRModel):
...
@@ -76,7 +76,7 @@ class EDVRModel(BaseSRModel):
self
.
current_iter
+=
1
self
.
current_iter
+=
1
# amp train with brute force implementation
# amp train with brute force implementation
def
train_iter_amp
(
self
,
optims
=
None
,
scaler
=
None
,
amp_level
=
'O1'
):
def
train_iter_amp
(
self
,
optims
=
None
,
scaler
s
=
None
,
amp_level
=
'O1'
):
optims
[
'optim'
].
clear_grad
()
optims
[
'optim'
].
clear_grad
()
if
self
.
tsa_iter
:
if
self
.
tsa_iter
:
if
self
.
current_iter
==
1
:
if
self
.
current_iter
==
1
:
...
@@ -97,9 +97,9 @@ class EDVRModel(BaseSRModel):
...
@@ -97,9 +97,9 @@ class EDVRModel(BaseSRModel):
loss_pixel
=
self
.
pixel_criterion
(
self
.
output
,
self
.
gt
)
loss_pixel
=
self
.
pixel_criterion
(
self
.
output
,
self
.
gt
)
self
.
losses
[
'loss_pixel'
]
=
loss_pixel
self
.
losses
[
'loss_pixel'
]
=
loss_pixel
scaled_loss
=
scaler
.
scale
(
loss_pixel
)
scaled_loss
=
scaler
s
[
0
]
.
scale
(
loss_pixel
)
scaled_loss
.
backward
()
scaled_loss
.
backward
()
scaler
.
minimize
(
optims
[
'optim'
],
scaled_loss
)
scaler
s
[
0
]
.
minimize
(
optims
[
'optim'
],
scaled_loss
)
self
.
current_iter
+=
1
self
.
current_iter
+=
1
...
...
ppgan/models/esrgan_model.py
浏览文件 @
96551765
...
@@ -29,6 +29,7 @@ class ESRGAN(BaseSRModel):
...
@@ -29,6 +29,7 @@ class ESRGAN(BaseSRModel):
ESRGAN paper: https://arxiv.org/pdf/1809.00219.pdf
ESRGAN paper: https://arxiv.org/pdf/1809.00219.pdf
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
generator
,
generator
,
discriminator
=
None
,
discriminator
=
None
,
...
@@ -127,3 +128,87 @@ class ESRGAN(BaseSRModel):
...
@@ -127,3 +128,87 @@ class ESRGAN(BaseSRModel):
else
:
else
:
l_total
.
backward
()
l_total
.
backward
()
optimizers
[
'optimG'
].
step
()
optimizers
[
'optimG'
].
step
()
# amp training
def train_iter_amp(self, optimizers=None, scalers=None, amp_level='O1'):
    """Run one mixed-precision (AMP) training iteration for ESRGAN.

    The generator is always updated. The discriminator is updated only
    when the model has a ``gan_criterion`` attribute.

    Args:
        optimizers (dict): must contain ``'optimG'``; must also contain
            ``'optimD'`` when ``gan_criterion`` is set.
        scalers (list): loss scalers exposing ``scale`` and ``minimize``.
            ``scalers[0]`` is used for the generator and ``scalers[1]``
            for the discriminator.
        amp_level (str): level forwarded to ``paddle.amp.auto_cast``.

    Side effects:
        Sets ``self.output`` and ``self.visual_items['output']`` and fills
        ``self.losses`` with the individual loss terms.
    """
    optim_g = optimizers['optimG']
    optim_g.clear_grad()
    l_total = 0

    # put loss computation in amp context
    with paddle.amp.auto_cast(enable=True, level=amp_level):
        self.output = self.nets['generator'](self.lq)
        self.visual_items['output'] = self.output
        # pixel loss
        if self.pixel_criterion:
            l_pix = self.pixel_criterion(self.output, self.gt)
            l_total += l_pix
            self.losses['loss_pix'] = l_pix
        if self.perceptual_criterion:
            l_g_percep, l_g_style = self.perceptual_criterion(
                self.output, self.gt)
            if l_g_percep is not None:
                l_total += l_g_percep
                self.losses['loss_percep'] = l_g_percep
            if l_g_style is not None:
                l_total += l_g_style
                self.losses['loss_style'] = l_g_style

    # gan loss (relativistic gan)
    use_gan = hasattr(self, 'gan_criterion')
    if use_gan:
        # the discriminator must not receive gradients from the G loss
        self.set_requires_grad(self.nets['discriminator'], False)

        # put fwd and loss computation in amp context
        with paddle.amp.auto_cast(enable=True, level=amp_level):
            real_d_pred = self.nets['discriminator'](self.gt).detach()
            fake_g_pred = self.nets['discriminator'](self.output)
            l_g_real = self.gan_criterion(real_d_pred -
                                          paddle.mean(fake_g_pred),
                                          False,
                                          is_disc=False)
            l_g_fake = self.gan_criterion(fake_g_pred -
                                          paddle.mean(real_d_pred),
                                          True,
                                          is_disc=False)
            l_g_gan = (l_g_real + l_g_fake) / 2

            l_total += l_g_gan
            self.losses['l_g_gan'] = l_g_gan

    # Generator update. Only scaler.minimize() is called: it unscales the
    # gradients and performs the optimizer update itself, so an extra
    # optimizer.step() beforehand would apply a second update using the
    # still-scaled gradients.
    scaled_l_total = scalers[0].scale(l_total)
    scaled_l_total.backward()
    scalers[0].minimize(optim_g, scaled_l_total)

    if not use_gan:
        return

    # Discriminator update.
    optim_d = optimizers['optimD']
    self.set_requires_grad(self.nets['discriminator'], True)
    optim_d.clear_grad()

    with paddle.amp.auto_cast(enable=True, level=amp_level):
        # real
        fake_d_pred = self.nets['discriminator'](self.output).detach()
        real_d_pred = self.nets['discriminator'](self.gt)
        l_d_real = self.gan_criterion(
            real_d_pred - paddle.mean(fake_d_pred), True,
            is_disc=True) * 0.5

        # fake
        fake_d_pred = self.nets['discriminator'](self.output.detach())
        l_d_fake = self.gan_criterion(
            fake_d_pred - paddle.mean(real_d_pred.detach()),
            False,
            is_disc=True) * 0.5

    l_temp = l_d_real + l_d_fake
    # The scaler that scaled the loss must also be the one that unscales
    # and steps, otherwise the loss-scale state of the two gets mixed up.
    scaled_l_temp = scalers[1].scale(l_temp)
    scaled_l_temp.backward()
    scalers[1].minimize(optim_d, scaled_l_temp)

    self.losses['l_d_real'] = l_d_real
    self.losses['l_d_fake'] = l_d_fake
    self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
    self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
ppgan/models/msvsr_model.py
浏览文件 @
96551765
...
@@ -98,7 +98,7 @@ class MultiStageVSRModel(BaseSRModel):
...
@@ -98,7 +98,7 @@ class MultiStageVSRModel(BaseSRModel):
self
.
current_iter
+=
1
self
.
current_iter
+=
1
# amp train with brute force implementation
# amp train with brute force implementation
def
train_iter_amp
(
self
,
optims
=
None
,
scaler
=
None
,
amp_level
=
'O1'
):
def
train_iter_amp
(
self
,
optims
=
None
,
scaler
s
=
None
,
amp_level
=
'O1'
):
optims
[
'optim'
].
clear_grad
()
optims
[
'optim'
].
clear_grad
()
if
self
.
fix_iter
:
if
self
.
fix_iter
:
if
self
.
current_iter
==
1
:
if
self
.
current_iter
==
1
:
...
@@ -133,9 +133,9 @@ class MultiStageVSRModel(BaseSRModel):
...
@@ -133,9 +133,9 @@ class MultiStageVSRModel(BaseSRModel):
if
'loss_pix'
in
_key
)
if
'loss_pix'
in
_key
)
self
.
losses
[
'loss'
]
=
self
.
loss
self
.
losses
[
'loss'
]
=
self
.
loss
scaled_loss
=
scaler
.
scale
(
self
.
loss
)
scaled_loss
=
scaler
s
[
0
]
.
scale
(
self
.
loss
)
scaled_loss
.
backward
()
scaled_loss
.
backward
()
scaler
.
minimize
(
optims
[
'optim'
],
scaled_loss
)
scaler
s
[
0
]
.
minimize
(
optims
[
'optim'
],
scaled_loss
)
self
.
current_iter
+=
1
self
.
current_iter
+=
1
...
...
ppgan/models/sr_model.py
浏览文件 @
96551765
...
@@ -27,6 +27,7 @@ from ..modules.init import reset_parameters
...
@@ -27,6 +27,7 @@ from ..modules.init import reset_parameters
class
BaseSRModel
(
BaseModel
):
class
BaseSRModel
(
BaseModel
):
"""Base SR model for single image super-resolution.
"""Base SR model for single image super-resolution.
"""
"""
def
__init__
(
self
,
generator
,
pixel_criterion
=
None
,
use_init_weight
=
False
):
def
__init__
(
self
,
generator
,
pixel_criterion
=
None
,
use_init_weight
=
False
):
"""
"""
Args:
Args:
...
@@ -65,6 +66,22 @@ class BaseSRModel(BaseModel):
...
@@ -65,6 +66,22 @@ class BaseSRModel(BaseModel):
loss_pixel
.
backward
()
loss_pixel
.
backward
()
optims
[
'optim'
].
step
()
optims
[
'optim'
].
step
()
# amp training
def train_iter_amp(self, optims=None, scalers=None, amp_level='O1'):
    """Run one mixed-precision (AMP) training step for the SR generator.

    Args:
        optims (dict): must contain the optimizer under key ``'optim'``.
        scalers (list): loss scalers; only ``scalers[0]`` is used.
        amp_level (str): level forwarded to ``paddle.amp.auto_cast``.

    Side effects:
        Sets ``self.output``, ``self.visual_items['output']`` and
        ``self.losses['loss_pixel']``.
    """
    optimizer = optims['optim']
    optimizer.clear_grad()

    # forward pass and pixel loss both run inside the amp context
    amp_ctx = paddle.amp.auto_cast(enable=True, level=amp_level)
    with amp_ctx:
        generator = self.nets['generator']
        self.output = generator(self.lq)
        self.visual_items['output'] = self.output
        pix_loss = self.pixel_criterion(self.output, self.gt)
        self.losses['loss_pixel'] = pix_loss

    # scale the loss, backpropagate, then let the scaler apply the update
    grad_scaler = scalers[0]
    scaled = grad_scaler.scale(pix_loss)
    scaled.backward()
    grad_scaler.minimize(optimizer, scaled)
def
test_iter
(
self
,
metrics
=
None
):
def
test_iter
(
self
,
metrics
=
None
):
self
.
nets
[
'generator'
].
eval
()
self
.
nets
[
'generator'
].
eval
()
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
...
@@ -84,6 +101,7 @@ class BaseSRModel(BaseModel):
...
@@ -84,6 +101,7 @@ class BaseSRModel(BaseModel):
def
init_sr_weight
(
net
):
def
init_sr_weight
(
net
):
def
reset_func
(
m
):
def
reset_func
(
m
):
if
hasattr
(
m
,
'weight'
)
and
(
not
isinstance
(
if
hasattr
(
m
,
'weight'
)
and
(
not
isinstance
(
m
,
(
nn
.
BatchNorm
,
nn
.
BatchNorm2D
))):
m
,
(
nn
.
BatchNorm
,
nn
.
BatchNorm2D
))):
...
...
test_tipc/configs/edvr/train_infer_python.txt
浏览文件 @
96551765
...
@@ -51,7 +51,7 @@ null:null
...
@@ -51,7 +51,7 @@ null:null
null:null
null:null
===========================train_benchmark_params==========================
===========================train_benchmark_params==========================
batch_size:64
batch_size:64
fp_items:fp32
fp_items:fp32
|fp16
total_iters:100
total_iters:100
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_cudnn_exhaustive_search=1
flags:FLAGS_cudnn_exhaustive_search=1
test_tipc/configs/esrgan/train_infer_python.txt
浏览文件 @
96551765
...
@@ -51,7 +51,7 @@ null:null
...
@@ -51,7 +51,7 @@ null:null
null:null
null:null
===========================train_benchmark_params==========================
===========================train_benchmark_params==========================
batch_size:32|64
batch_size:32|64
fp_items:fp32
fp_items:fp32
|fp16
total_iters:500
total_iters:500
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_cudnn_exhaustive_search=1
flags:FLAGS_cudnn_exhaustive_search=1
test_tipc/configs/msvsr/train_infer_python.txt
浏览文件 @
96551765
...
@@ -51,7 +51,7 @@ null:null
...
@@ -51,7 +51,7 @@ null:null
null:null
null:null
===========================train_benchmark_params==========================
===========================train_benchmark_params==========================
batch_size:2|4
batch_size:2|4
fp_items:fp32
fp_items:fp32
|fp16
total_iters:60
total_iters:60
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_cudnn_exhaustive_search=1
flags:FLAGS_cudnn_exhaustive_search=1
...
...
test_tipc/prepare.sh
浏览文件 @
96551765
...
@@ -197,5 +197,6 @@ elif [ ${MODE} = "cpp_infer" ]; then
...
@@ -197,5 +197,6 @@ elif [ ${MODE} = "cpp_infer" ]; then
rm
-rf
./inference/msvsr
*
rm
-rf
./inference/msvsr
*
wget
-nc
-P
./inference https://paddlegan.bj.bcebos.com/static_model/msvsr.tar
--no-check-certificate
wget
-nc
-P
./inference https://paddlegan.bj.bcebos.com/static_model/msvsr.tar
--no-check-certificate
cd
./inference
&&
tar
xf msvsr.tar
&&
cd
../
cd
./inference
&&
tar
xf msvsr.tar
&&
cd
../
wget
-nc
-P
./data https://paddlegan.bj.bcebos.com/datasets/low_res.mp4
--no-check-certificate
fi
fi
fi
fi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录