Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
ae28ccf8
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae28ccf8
编写于
7月 03, 2020
作者:
C
ceci3
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add multi-gpu
上级
27457a04
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
116 addition
and
45 deletion
+116
-45
demo/gan_compression/distillers/base_resnet_distiller.py
demo/gan_compression/distillers/base_resnet_distiller.py
+40
-14
demo/gan_compression/distillers/resnet_distiller.py
demo/gan_compression/distillers/resnet_distiller.py
+19
-13
demo/gan_compression/gan_compression.py
demo/gan_compression/gan_compression.py
+16
-5
demo/gan_compression/models/base_model.py
demo/gan_compression/models/base_model.py
+1
-1
demo/gan_compression/models/cycle_gan_model.py
demo/gan_compression/models/cycle_gan_model.py
+15
-5
demo/gan_compression/models/generator/resnet_generator.py
demo/gan_compression/models/generator/resnet_generator.py
+4
-1
demo/gan_compression/models/generator/sub_mobile_generator.py
.../gan_compression/models/generator/sub_mobile_generator.py
+1
-1
demo/gan_compression/models/generator/super_generator.py
demo/gan_compression/models/generator/super_generator.py
+3
-3
demo/gan_compression/models/network.py
demo/gan_compression/models/network.py
+11
-0
demo/gan_compression/supernets/resnet_supernet.py
demo/gan_compression/supernets/resnet_supernet.py
+6
-2
未找到文件。
demo/gan_compression/distillers/base_resnet_distiller.py
浏览文件 @
ae28ccf8
...
@@ -88,16 +88,32 @@ class BaseResnetDistiller(BaseModel):
...
@@ -88,16 +88,32 @@ class BaseResnetDistiller(BaseModel):
self
.
netG_pretrained
=
network
.
define_G
(
self
.
netG_pretrained
=
network
.
define_G
(
cfgs
.
input_nc
,
cfgs
.
output_nc
,
cfgs
.
pretrained_ngf
,
cfgs
.
input_nc
,
cfgs
.
output_nc
,
cfgs
.
pretrained_ngf
,
cfgs
.
pretrained_netG
,
cfgs
.
norm_type
,
0
)
cfgs
.
pretrained_netG
,
cfgs
.
norm_type
,
0
)
if
self
.
cfgs
.
use_parallel
:
self
.
netG_pretrained
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
self
.
netG_pretrained
,
self
.
cfgs
.
strategy
)
self
.
netD
=
network
.
define_D
(
cfgs
.
output_nc
,
cfgs
.
ndf
,
cfgs
.
netD
,
self
.
netD
=
network
.
define_D
(
cfgs
.
output_nc
,
cfgs
.
ndf
,
cfgs
.
netD
,
cfgs
.
norm_type
,
cfgs
.
n_layer_D
)
cfgs
.
norm_type
,
cfgs
.
n_layer_D
)
if
self
.
cfgs
.
use_parallel
:
self
.
netG_teacher
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
self
.
netG_teacher
,
self
.
cfgs
.
strategy
)
self
.
netG_student
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
self
.
netG_student
,
self
.
cfgs
.
strategy
)
self
.
netD
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
self
.
netD
,
self
.
cfgs
.
strategy
)
self
.
netG_teacher
.
eval
()
self
.
netG_teacher
.
eval
()
self
.
netG_student
.
train
()
self
.
netG_student
.
train
()
self
.
netD
.
train
()
self
.
netD
.
train
()
### [9, 12, 15, 18]
### [9, 12, 15, 18]
self
.
mapping_layers
=
[
'model.%d'
%
i
for
i
in
range
(
9
,
21
,
3
)]
self
.
mapping_layers
=
[
'_layers.model.%d'
%
i
for
i
in
range
(
9
,
21
,
3
)
]
if
self
.
cfgs
.
use_parallel
else
[
'model.%d'
%
i
for
i
in
range
(
9
,
21
,
3
)
]
self
.
netAs
=
[]
self
.
netAs
=
[]
self
.
Tacts
,
self
.
Sacts
=
{},
{}
self
.
Tacts
,
self
.
Sacts
=
{},
{}
...
@@ -157,8 +173,8 @@ class BaseResnetDistiller(BaseModel):
...
@@ -157,8 +173,8 @@ class BaseResnetDistiller(BaseModel):
self
.
is_best
=
False
self
.
is_best
=
False
def
setup
(
self
):
def
setup
(
self
,
model_weight
=
None
):
self
.
load_networks
()
self
.
load_networks
(
model_weight
)
if
self
.
cfgs
.
lambda_distill
>
0
:
if
self
.
cfgs
.
lambda_distill
>
0
:
...
@@ -183,30 +199,37 @@ class BaseResnetDistiller(BaseModel):
...
@@ -183,30 +199,37 @@ class BaseResnetDistiller(BaseModel):
def
set_single_input
(
self
,
inputs
):
def
set_single_input
(
self
,
inputs
):
self
.
real_A
=
inputs
[
0
]
self
.
real_A
=
inputs
[
0
]
def
load_networks
(
self
):
def
load_networks
(
self
,
model_weight
=
None
):
if
self
.
cfgs
.
restore_teacher_G_path
is
None
:
if
self
.
cfgs
.
restore_teacher_G_path
is
None
:
assert
len
(
model_weight
)
!=
0
,
"restore_teacher_G_path and model_weight cannot be None at the same time."
if
self
.
cfgs
.
direction
==
'AtoB'
:
if
self
.
cfgs
.
direction
==
'AtoB'
:
teacher_G_path
=
os
.
path
.
join
(
self
.
cfgs
.
save_dir
,
'mobile'
,
key
=
'netG_A'
if
'netG_A'
in
model_weight
else
'netG_teacher'
'last_netG_A'
)
self
.
netG_teacher
.
set_dict
(
model_weight
[
key
]
)
else
:
else
:
teacher_G_path
=
os
.
path
.
join
(
self
.
cfgs
.
save_dir
,
'mobile'
,
key
=
'netG_B'
if
'netG_B'
in
model_weight
else
'netG_teacher'
'last_netG_B'
)
self
.
netG_teacher
.
set_dict
(
model_weight
[
key
]
)
else
:
else
:
teacher_G_path
=
self
.
cfgs
.
restore_teacher_G_path
util
.
load_network
(
self
.
netG_teacher
,
self
.
cfgs
.
teacher_G_path
)
util
.
load_network
(
self
.
netG_teacher
,
teacher_G_path
)
if
self
.
cfgs
.
restore_student_G_path
is
not
None
:
if
self
.
cfgs
.
restore_student_G_path
is
not
None
:
util
.
load_network
(
self
.
netG_student
,
util
.
load_network
(
self
.
netG_student
,
self
.
cfgs
.
restore_student_G_path
)
self
.
cfgs
.
restore_student_G_path
)
else
:
else
:
if
self
.
task
==
'supernet'
:
if
self
.
task
==
'supernet'
:
student_G_path
=
os
.
path
.
join
(
self
.
cfgs
.
save_dir
,
'distiller'
,
self
.
netG_student
.
set_dict
(
model_weight
[
'netG_student'
])
'last_stu_netG'
)
util
.
load_network
(
self
.
netG_student
,
student_G_path
)
if
self
.
cfgs
.
restore_D_path
is
not
None
:
if
self
.
cfgs
.
restore_D_path
is
not
None
:
util
.
load_network
(
self
.
netD
,
self
.
cfgs
.
restore_D_path
)
util
.
load_network
(
self
.
netD
,
self
.
cfgs
.
restore_D_path
)
else
:
if
self
.
cfgs
.
direction
==
'AtoB'
:
key
=
'netD_A'
if
'netD_A'
in
model_weight
else
'netD'
self
.
netD
.
set_dict
(
model_weight
[
key
])
else
:
key
=
'netD_B'
if
'netD_B'
in
model_weight
else
'netD'
self
.
netD
.
set_dict
(
model_weight
[
key
])
if
self
.
cfgs
.
restore_A_path
is
not
None
:
if
self
.
cfgs
.
restore_A_path
is
not
None
:
for
i
,
netA
in
enumerate
(
self
.
netAs
):
for
i
,
netA
in
enumerate
(
self
.
netAs
):
netA_path
=
'%s-%d.pth'
%
(
self
.
cfgs
.
restore_A_path
,
i
)
netA_path
=
'%s-%d.pth'
%
(
self
.
cfgs
.
restore_A_path
,
i
)
...
@@ -232,6 +255,9 @@ class BaseResnetDistiller(BaseModel):
...
@@ -232,6 +255,9 @@ class BaseResnetDistiller(BaseModel):
self
.
loss_D
=
(
self
.
loss_D_fake
+
self
.
loss_D_real
)
*
0.5
self
.
loss_D
=
(
self
.
loss_D_fake
+
self
.
loss_D_real
)
*
0.5
self
.
loss_D
.
backward
()
self
.
loss_D
.
backward
()
if
self
.
cfgs
.
use_parallel
:
self
.
netD
.
apply_collective_grads
()
def
calc_distill_loss
(
self
):
def
calc_distill_loss
(
self
):
raise
NotImplementedError
raise
NotImplementedError
...
...
demo/gan_compression/distillers/resnet_distiller.py
浏览文件 @
ae28ccf8
...
@@ -117,6 +117,9 @@ class ResnetDistiller(BaseResnetDistiller):
...
@@ -117,6 +117,9 @@ class ResnetDistiller(BaseResnetDistiller):
self
.
loss_G
=
self
.
loss_G_gan
+
self
.
loss_G_recon
+
self
.
loss_G_distill
self
.
loss_G
=
self
.
loss_G_gan
+
self
.
loss_G_recon
+
self
.
loss_G_distill
self
.
loss_G
.
backward
()
self
.
loss_G
.
backward
()
if
self
.
cfgs
.
use_parallel
:
self
.
netG_student
.
apply_collective_grads
()
def
optimize_parameter
(
self
):
def
optimize_parameter
(
self
):
self
.
forward
()
self
.
forward
()
...
@@ -130,24 +133,27 @@ class ResnetDistiller(BaseResnetDistiller):
...
@@ -130,24 +133,27 @@ class ResnetDistiller(BaseResnetDistiller):
self
.
optimizer_G
.
optimizer
.
minimize
(
self
.
loss_G
)
self
.
optimizer_G
.
optimizer
.
minimize
(
self
.
loss_G
)
self
.
optimizer_G
.
optimizer
.
clear_gradients
()
self
.
optimizer_G
.
optimizer
.
clear_gradients
()
def
load_networks
(
self
):
def
load_networks
(
self
,
model_weight
=
None
):
load_pretrain
=
True
if
self
.
cfgs
.
restore_pretrained_G_path
!=
False
:
if
self
.
cfgs
.
restore_pretrained_G_path
is
not
None
:
if
self
.
cfgs
.
restore_pretrained_G_path
!=
None
:
pretrained_G_path
=
self
.
cfgs
.
restore_pretrained_G_path
pretrained_G_path
=
self
.
cfgs
.
restore_pretrained_G_path
else
:
util
.
load_network
(
self
.
netG_pretrained
,
pretrained_G_path
)
pretrained_G_path
=
os
.
path
.
join
(
self
.
cfgs
.
save_dir
,
'mobile'
,
else
:
'last_netG_B'
)
assert
len
(
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
pretrained_G_path
,
'pdparams'
)):
model_weight
load_pretrain
=
False
)
!=
0
,
"restore_pretrained_G_path and model_weight can not be None at the same time, if you donnot want to load pretrained model, please set restore_pretrained_G_path=Fasle"
if
self
.
cfgs
.
direction
==
'AtoB'
:
self
.
netG_pretrained
.
set_dict
(
model_weight
[
'netG_A'
])
else
:
self
.
netG_pretrained
.
set_dict
(
model_weight
[
'netG_B'
])
if
load_pretrain
:
util
.
load_network
(
self
.
netG_pretrained
,
pretrained_G_path
)
load_pretrained_weight
(
load_pretrained_weight
(
self
.
cfgs
.
pretrained_netG
,
self
.
cfgs
.
student_netG
,
self
.
cfgs
.
pretrained_netG
,
self
.
cfgs
.
distiller_
student_netG
,
self
.
netG_pretrained
,
self
.
netG_student
,
self
.
netG_pretrained
,
self
.
netG_student
,
self
.
cfgs
.
pretrained_ngf
,
self
.
cfgs
.
student_ngf
)
self
.
cfgs
.
pretrained_ngf
,
self
.
cfgs
.
student_ngf
)
del
self
.
netG_pretrained
del
self
.
netG_pretrained
super
(
ResnetDistiller
,
self
).
load_networks
()
super
(
ResnetDistiller
,
self
).
load_networks
(
model_weight
)
def
evaluate_model
(
self
,
step
):
def
evaluate_model
(
self
,
step
):
ret
=
{}
ret
=
{}
...
...
demo/gan_compression/gan_compression.py
浏览文件 @
ae28ccf8
...
@@ -53,7 +53,8 @@ class gan_compression:
...
@@ -53,7 +53,8 @@ class gan_compression:
def
start_train
(
self
):
def
start_train
(
self
):
steps
=
self
.
cfgs
.
task
.
split
(
'+'
)
steps
=
self
.
cfgs
.
task
.
split
(
'+'
)
for
step
in
steps
:
model_weight
=
{}
for
idx
,
step
in
enumerate
(
steps
):
if
step
==
'mobile'
:
if
step
==
'mobile'
:
from
models
import
create_model
from
models
import
create_model
elif
step
==
'distiller'
:
elif
step
==
'distiller'
:
...
@@ -66,9 +67,16 @@ class gan_compression:
...
@@ -66,9 +67,16 @@ class gan_compression:
print
(
print
(
"============================= start train {} =============================="
.
"============================= start train {} =============================="
.
format
(
step
))
format
(
step
))
fluid
.
enable_imperative
()
fluid
.
enable_imperative
(
place
=
self
.
cfgs
.
place
)
if
self
.
cfgs
.
use_parallel
and
idx
==
0
:
strategy
=
fluid
.
dygraph
.
parallel
.
prepare_context
()
setattr
(
self
.
cfgs
,
'strategy'
,
strategy
)
model
=
create_model
(
self
.
cfgs
)
model
=
create_model
(
self
.
cfgs
)
model
.
setup
()
model
.
setup
(
model_weight
)
### clear model_weight every step
model_weight
=
{}
_train_dataloader
,
_
=
create_data
(
self
.
cfgs
)
_train_dataloader
,
_
=
create_data
(
self
.
cfgs
)
...
@@ -90,6 +98,11 @@ class gan_compression:
...
@@ -90,6 +98,11 @@ class gan_compression:
message
+=
'%s: %.3f '
%
(
k
,
v
)
message
+=
'%s: %.3f '
%
(
k
,
v
)
logging
.
info
(
message
)
logging
.
info
(
message
)
if
epoch_id
==
(
epochs
-
1
):
for
name
in
model
.
model_names
:
model_weight
[
name
]
=
model
.
_sub_layers
[
name
].
state_dict
()
save_model
=
(
not
self
.
cfgs
.
use_parallel
)
or
(
save_model
=
(
not
self
.
cfgs
.
use_parallel
)
or
(
self
.
cfgs
.
use_parallel
and
self
.
cfgs
.
use_parallel
and
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
)
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
)
...
@@ -97,8 +110,6 @@ class gan_compression:
...
@@ -97,8 +110,6 @@ class gan_compression:
epochs
-
1
)
and
save_model
:
epochs
-
1
)
and
save_model
:
model
.
evaluate_model
(
epoch_id
)
model
.
evaluate_model
(
epoch_id
)
model
.
save_network
(
epoch_id
)
model
.
save_network
(
epoch_id
)
if
epoch_id
==
(
epochs
-
1
):
model
.
save_network
(
'last'
)
print
(
"="
*
80
)
print
(
"="
*
80
)
...
...
demo/gan_compression/models/base_model.py
浏览文件 @
ae28ccf8
...
@@ -24,7 +24,7 @@ class BaseModel(fluid.dygraph.Layer):
...
@@ -24,7 +24,7 @@ class BaseModel(fluid.dygraph.Layer):
def
set_input
(
self
,
inputs
):
def
set_input
(
self
,
inputs
):
raise
NotImplementedError
raise
NotImplementedError
def
setup
(
self
):
def
setup
(
self
,
model_weight
=
None
):
self
.
load_network
()
self
.
load_network
()
def
load_network
(
self
):
def
load_network
(
self
):
...
...
demo/gan_compression/models/cycle_gan_model.py
浏览文件 @
ae28ccf8
...
@@ -94,7 +94,7 @@ class CycleGAN(BaseModel):
...
@@ -94,7 +94,7 @@ class CycleGAN(BaseModel):
'D_A'
,
'G_A'
,
'G_cycle_A'
,
'G_idt_A'
,
'D_B'
,
'G_B'
,
'G_cycle_B'
,
'D_A'
,
'G_A'
,
'G_cycle_A'
,
'G_idt_A'
,
'D_B'
,
'G_B'
,
'G_cycle_B'
,
'G_idt_B'
'G_idt_B'
]
]
self
.
model_names
=
[
'
G_A'
,
'G_B'
,
'D_A'
,
'
D_B'
]
self
.
model_names
=
[
'
netG_A'
,
'netG_B'
,
'netD_A'
,
'net
D_B'
]
self
.
netG_A
=
network
.
define_G
(
cfgs
.
input_nc
,
cfgs
.
output_nc
,
cfgs
.
ngf
,
self
.
netG_A
=
network
.
define_G
(
cfgs
.
input_nc
,
cfgs
.
output_nc
,
cfgs
.
ngf
,
cfgs
.
netG
,
cfgs
.
norm_type
,
cfgs
.
netG
,
cfgs
.
norm_type
,
...
@@ -107,6 +107,16 @@ class CycleGAN(BaseModel):
...
@@ -107,6 +107,16 @@ class CycleGAN(BaseModel):
self
.
netD_B
=
network
.
define_D
(
cfgs
.
input_nc
,
cfgs
.
ndf
,
cfgs
.
netD
,
self
.
netD_B
=
network
.
define_D
(
cfgs
.
input_nc
,
cfgs
.
ndf
,
cfgs
.
netD
,
cfgs
.
norm_type
,
cfgs
.
n_layer_D
)
cfgs
.
norm_type
,
cfgs
.
n_layer_D
)
if
self
.
cfgs
.
use_parallel
:
self
.
netG_A
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
self
.
netG_A
,
self
.
cfgs
.
strategy
)
self
.
netG_B
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
self
.
netG_B
,
self
.
cfgs
.
strategy
)
self
.
netD_A
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
self
.
netD_A
,
self
.
cfgs
.
strategy
)
self
.
netD_B
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
self
.
netD_B
,
self
.
cfgs
.
strategy
)
if
cfgs
.
lambda_identity
>
0.0
:
if
cfgs
.
lambda_identity
>
0.0
:
assert
(
cfgs
.
input_nc
==
cfgs
.
output_nc
)
assert
(
cfgs
.
input_nc
==
cfgs
.
output_nc
)
self
.
fake_A_pool
=
ImagePool
(
cfgs
.
pool_size
)
self
.
fake_A_pool
=
ImagePool
(
cfgs
.
pool_size
)
...
@@ -159,12 +169,12 @@ class CycleGAN(BaseModel):
...
@@ -159,12 +169,12 @@ class CycleGAN(BaseModel):
def
set_single_input
(
self
,
inputs
):
def
set_single_input
(
self
,
inputs
):
self
.
real_A
=
inputs
[
0
]
self
.
real_A
=
inputs
[
0
]
def
setup
(
self
):
def
setup
(
self
,
model_weight
=
None
):
self
.
load_network
()
self
.
load_network
()
def
load_network
(
self
):
def
load_network
(
self
):
for
name
in
self
.
model_names
:
for
name
in
self
.
model_names
:
net
=
getattr
(
self
,
'net'
+
name
,
None
)
net
=
getattr
(
self
,
name
,
None
)
path
=
getattr
(
self
.
cfgs
,
'restore_%s_path'
%
name
,
None
)
path
=
getattr
(
self
.
cfgs
,
'restore_%s_path'
%
name
,
None
)
if
path
is
not
None
:
if
path
is
not
None
:
util
.
load_network
(
net
,
path
)
util
.
load_network
(
net
,
path
)
...
@@ -172,10 +182,10 @@ class CycleGAN(BaseModel):
...
@@ -172,10 +182,10 @@ class CycleGAN(BaseModel):
def
save_network
(
self
,
epoch
):
def
save_network
(
self
,
epoch
):
for
name
in
self
.
model_names
:
for
name
in
self
.
model_names
:
if
isinstance
(
name
,
str
):
if
isinstance
(
name
,
str
):
save_filename
=
'%s_
net
%s'
%
(
epoch
,
name
)
save_filename
=
'%s_%s'
%
(
epoch
,
name
)
save_path
=
os
.
path
.
join
(
self
.
cfgs
.
save_dir
,
'mobile'
,
save_path
=
os
.
path
.
join
(
self
.
cfgs
.
save_dir
,
'mobile'
,
save_filename
)
save_filename
)
net
=
getattr
(
self
,
'net'
+
name
)
net
=
getattr
(
self
,
name
)
fluid
.
save_dygraph
(
net
.
state_dict
(),
save_path
)
fluid
.
save_dygraph
(
net
.
state_dict
(),
save_path
)
def
forward
(
self
):
def
forward
(
self
):
...
...
demo/gan_compression/models/generator/resnet_generator.py
浏览文件 @
ae28ccf8
import
functools
import
functools
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.tensor
as
tensor
from
paddle.fluid.dygraph.nn
import
InstanceNorm
,
Conv2D
,
Conv2DTranspose
from
paddle.fluid.dygraph.nn
import
InstanceNorm
,
Conv2D
,
Conv2DTranspose
from
paddle.nn.layer
import
ReLU
,
Pad2D
from
paddle.nn.layer
import
ReLU
,
Pad2D
from
paddleslim.models.dygraph.modules
import
ResnetBlock
from
paddleslim.models.dygraph.modules
import
ResnetBlock
...
@@ -56,6 +57,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
...
@@ -56,6 +57,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
for
i
in
range
(
n_downsampling
):
for
i
in
range
(
n_downsampling
):
mult
=
2
**
(
n_downsampling
-
i
)
mult
=
2
**
(
n_downsampling
-
i
)
output_size
=
(
i
+
1
)
*
(
self
.
cfgs
.
crop_size
/
2
)
self
.
model
.
extend
([
self
.
model
.
extend
([
Conv2DTranspose
(
Conv2DTranspose
(
ngf
*
mult
,
ngf
*
mult
,
...
@@ -63,6 +65,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
...
@@ -63,6 +65,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
filter_size
=
3
,
filter_size
=
3
,
stride
=
2
,
stride
=
2
,
padding
=
1
,
padding
=
1
,
output_size
=
output_size
,
bias_attr
=
use_bias
),
Pad2D
(
bias_attr
=
use_bias
),
Pad2D
(
paddings
=
[
0
,
1
,
0
,
1
],
mode
=
'constant'
,
pad_value
=
0.0
),
paddings
=
[
0
,
1
,
0
,
1
],
mode
=
'constant'
,
pad_value
=
0.0
),
norm_layer
(
int
(
ngf
*
mult
/
2
)),
ReLU
()
norm_layer
(
int
(
ngf
*
mult
/
2
)),
ReLU
()
...
@@ -72,7 +75,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
...
@@ -72,7 +75,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
self
.
model
.
extend
([
Conv2D
(
ngf
,
output_nc
,
filter_size
=
7
,
padding
=
0
)])
self
.
model
.
extend
([
Conv2D
(
ngf
,
output_nc
,
filter_size
=
7
,
padding
=
0
)])
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
y
=
fluid
.
layers
.
clamp
(
inputs
,
min
=-
1.0
,
max
=
1.0
)
y
=
tensor
.
clamp
(
inputs
,
min
=-
1.0
,
max
=
1.0
)
for
sublayer
in
self
.
model
:
for
sublayer
in
self
.
model
:
y
=
sublayer
(
y
)
y
=
sublayer
(
y
)
y
=
fluid
.
layers
.
tanh
(
y
)
y
=
fluid
.
layers
.
tanh
(
y
)
...
...
demo/gan_compression/models/generator/sub_mobile_generator.py
浏览文件 @
ae28ccf8
...
@@ -75,7 +75,7 @@ class SubMobileResnetGenerator(fluid.dygraph.Layer):
...
@@ -75,7 +75,7 @@ class SubMobileResnetGenerator(fluid.dygraph.Layer):
for
i
in
range
(
n_downsampling
):
for
i
in
range
(
n_downsampling
):
out_c
=
config
[
'channels'
][
offset
+
i
]
out_c
=
config
[
'channels'
][
offset
+
i
]
mult
=
2
**
(
n_downsampling
-
i
)
mult
=
2
**
(
n_downsampling
-
i
)
output_size
=
(
i
+
1
)
*
128
output_size
=
(
i
+
1
)
*
(
self
.
cfgs
.
crop_size
/
2
)
self
.
model
.
extend
([
self
.
model
.
extend
([
Conv2DTranspose
(
Conv2DTranspose
(
in_c
*
mult
,
in_c
*
mult
,
...
...
demo/gan_compression/models/generator/super_generator.py
浏览文件 @
ae28ccf8
...
@@ -75,19 +75,19 @@ class SuperMobileResnetGenerator(fluid.dygraph.Layer):
...
@@ -75,19 +75,19 @@ class SuperMobileResnetGenerator(fluid.dygraph.Layer):
input_channel
,
input_channel
,
output_nc
,
output_nc
,
ngf
,
ngf
,
norm_layer
=
Batch
Norm
,
norm_layer
=
Instance
Norm
,
dropout_rate
=
0
,
dropout_rate
=
0
,
n_blocks
=
6
,
n_blocks
=
6
,
padding_type
=
'reflect'
):
padding_type
=
'reflect'
):
assert
n_blocks
>=
0
assert
n_blocks
>=
0
super
(
SuperMobileResnetGenerator
,
self
).
__init__
()
super
(
SuperMobileResnetGenerator
,
self
).
__init__
()
use_bias
=
norm_layer
==
InstanceNorm
if
norm_layer
.
func
==
InstanceNorm
or
norm_layer
==
InstanceNorm
:
if
norm_layer
.
func
==
InstanceNorm
or
norm_layer
==
InstanceNorm
:
norm_layer
=
SuperInstanceNorm
norm_layer
=
SuperInstanceNorm
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
use_bias
=
norm_layer
==
InstanceNorm
self
.
model
=
fluid
.
dygraph
.
LayerList
([])
self
.
model
=
fluid
.
dygraph
.
LayerList
([])
self
.
model
.
extend
([
self
.
model
.
extend
([
Pad2D
(
Pad2D
(
...
...
demo/gan_compression/models/network.py
浏览文件 @
ae28ccf8
...
@@ -18,6 +18,7 @@ from discrimitor import NLayerDiscriminator
...
@@ -18,6 +18,7 @@ from discrimitor import NLayerDiscriminator
from
generator.resnet_generator
import
ResnetGenerator
from
generator.resnet_generator
import
ResnetGenerator
from
generator.mobile_generator
import
MobileResnetGenerator
from
generator.mobile_generator
import
MobileResnetGenerator
from
generator.super_generator
import
SuperMobileResnetGenerator
from
generator.super_generator
import
SuperMobileResnetGenerator
from
generator.sub_mobile_generator
import
SubMobileResnetGenerator
class
Identity
(
fluid
.
dygraph
.
Layer
):
class
Identity
(
fluid
.
dygraph
.
Layer
):
...
@@ -88,6 +89,16 @@ def define_G(input_nc,
...
@@ -88,6 +89,16 @@ def define_G(input_nc,
norm_layer
=
norm_layer
,
norm_layer
=
norm_layer
,
dropout_rate
=
dropout_rate
,
dropout_rate
=
dropout_rate
,
n_blocks
=
9
)
n_blocks
=
9
)
elif
netG
==
'sub_mobile_resnet_9blocks'
:
assert
self
.
cfgs
.
config_str
is
not
None
config
=
decode_config
(
self
.
cfgs
.
config_str
)
net
=
SubMobileResnetGenerator
(
input_nc
,
output_nc
,
config
,
norm_layer
=
norm_layer
,
dropout_rate
=
dropout_rate
,
n_blocks
=
9
)
return
net
return
net
...
...
demo/gan_compression/supernets/resnet_supernet.py
浏览文件 @
ae28ccf8
...
@@ -89,7 +89,10 @@ class ResnetSupernet(BaseResnetDistiller):
...
@@ -89,7 +89,10 @@ class ResnetSupernet(BaseResnetDistiller):
with
fluid
.
dygraph
.
no_grad
():
with
fluid
.
dygraph
.
no_grad
():
self
.
Tfake_B
=
self
.
netG_teacher
(
self
.
real_A
)
self
.
Tfake_B
=
self
.
netG_teacher
(
self
.
real_A
)
self
.
Tfake_B
.
stop_gradient
=
True
self
.
Tfake_B
.
stop_gradient
=
True
self
.
netG_student
.
configs
=
config
if
self
.
cfgs
.
use_parallel
:
self
.
netG_student
.
_layers
.
configs
=
config
else
:
self
.
netG_student
.
configs
=
config
self
.
Sfake_B
=
self
.
netG_student
(
self
.
real_A
)
self
.
Sfake_B
=
self
.
netG_student
(
self
.
real_A
)
def
calc_distill_loss
(
self
):
def
calc_distill_loss
(
self
):
...
@@ -137,7 +140,8 @@ class ResnetSupernet(BaseResnetDistiller):
...
@@ -137,7 +140,8 @@ class ResnetSupernet(BaseResnetDistiller):
def
evaluate_model
(
self
,
step
):
def
evaluate_model
(
self
,
step
):
ret
=
{}
ret
=
{}
self
.
is_best
=
False
self
.
is_best
=
False
save_dir
=
os
.
path
.
join
(
self
.
cfgs
.
save_dir
,
'eval'
,
str
(
step
))
save_dir
=
os
.
path
.
join
(
self
.
cfgs
.
save_dir
,
'supernet'
,
'eval'
,
str
(
step
))
if
not
os
.
path
.
exists
(
save_dir
):
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
os
.
makedirs
(
save_dir
)
self
.
netG_student
.
eval
()
self
.
netG_student
.
eval
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录