Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
3a33f1a4
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3a33f1a4
编写于
10月 29, 2020
作者:
B
Bai Yifan
提交者:
GitHub
10月 29, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
migrate distillation code to paddle2.0 (#486)
上级
9fe2d24b
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
160 addition
and
175 deletion
+160
-175
demo/distillation/distill.py
demo/distillation/distill.py
+53
-54
paddleslim/dist/single_distiller.py
paddleslim/dist/single_distiller.py
+29
-25
tests/test_fsp_loss.py
tests/test_fsp_loss.py
+17
-21
tests/test_l2_loss.py
tests/test_l2_loss.py
+15
-19
tests/test_loss.py
tests/test_loss.py
+20
-26
tests/test_merge.py
tests/test_merge.py
+10
-10
tests/test_soft_label_loss.py
tests/test_soft_label_loss.py
+16
-20
未找到文件。
demo/distillation/distill.py
浏览文件 @
3a33f1a4
...
@@ -10,7 +10,6 @@ import paddle
...
@@ -10,7 +10,6 @@ import paddle
import
argparse
import
argparse
import
functools
import
functools
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
sys
.
path
[
0
]
=
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
)
sys
.
path
[
0
]
=
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
)
import
models
import
models
from
utility
import
add_arguments
,
print_arguments
,
_download
,
_decompress
from
utility
import
add_arguments
,
print_arguments
,
_download
,
_decompress
...
@@ -47,34 +46,35 @@ model_list = [m for m in dir(models) if "__" not in m]
...
@@ -47,34 +46,35 @@ model_list = [m for m in dir(models) if "__" not in m]
def
piecewise_decay
(
args
):
def
piecewise_decay
(
args
):
if
args
.
use_gpu
:
if
args
.
use_gpu
:
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
devices_num
=
paddle
.
fluid
.
core
.
get_cuda_device_count
()
else
:
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
step
=
int
(
step
=
int
(
math
.
ceil
(
float
(
args
.
total_images
)
/
args
.
batch_size
)
/
devices_num
)
math
.
ceil
(
float
(
args
.
total_images
)
/
args
.
batch_size
)
/
devices_num
)
bd
=
[
step
*
e
for
e
in
args
.
step_epochs
]
bd
=
[
step
*
e
for
e
in
args
.
step_epochs
]
lr
=
[
args
.
lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
bd
)
+
1
)]
lr
=
[
args
.
lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
bd
)
+
1
)]
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
bd
,
values
=
lr
)
learning_rate
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
optimizer
=
fluid
.
optimizer
.
Momentum
(
boundaries
=
bd
,
values
=
lr
,
verbose
=
False
)
optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
learning_rate
,
learning_rate
=
learning_rate
,
momentum
=
args
.
momentum_rate
,
momentum
=
args
.
momentum_rate
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
args
.
l2_decay
))
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
args
.
l2_decay
))
return
learning_rate
,
optimizer
return
learning_rate
,
optimizer
def
cosine_decay
(
args
):
def
cosine_decay
(
args
):
if
cfg
.
use_gpu
:
if
args
.
use_gpu
:
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
devices_num
=
paddle
.
fluid
.
core
.
get_cuda_device_count
()
else
:
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
step
=
int
(
step
=
int
(
math
.
ceil
(
float
(
args
.
total_images
)
/
args
.
batch_size
)
/
devices_num
)
math
.
ceil
(
float
(
args
.
total_images
)
/
args
.
batch_size
)
/
devices_num
)
learning_rate
=
fluid
.
layers
.
cosine_d
ecay
(
learning_rate
=
paddle
.
optimizer
.
lr
.
CosineAnnealingD
ecay
(
learning_rate
=
args
.
lr
,
step_each_epoch
=
step
,
epochs
=
args
.
num_epochs
)
learning_rate
=
args
.
lr
,
T_max
=
step
*
args
.
num_epochs
,
verbose
=
False
)
optimizer
=
fluid
.
optimizer
.
Momentum
(
optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
learning_rate
,
learning_rate
=
learning_rate
,
momentum
=
args
.
momentum_rate
,
momentum
=
args
.
momentum_rate
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
args
.
l2_decay
))
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
args
.
l2_decay
))
return
learning_rate
,
optimizer
return
learning_rate
,
optimizer
...
@@ -104,20 +104,21 @@ def compress(args):
...
@@ -104,20 +104,21 @@ def compress(args):
assert
args
.
model
in
model_list
,
"{} is not in lists: {}"
.
format
(
args
.
model
,
assert
args
.
model
in
model_list
,
"{} is not in lists: {}"
.
format
(
args
.
model
,
model_list
)
model_list
)
student_program
=
fluid
.
Program
()
student_program
=
paddle
.
static
.
Program
()
s_startup
=
fluid
.
Program
()
s_startup
=
paddle
.
static
.
Program
()
with
fluid
.
program_guard
(
student_program
,
s_startup
):
with
paddle
.
static
.
program_guard
(
student_program
,
s_startup
):
with
fluid
.
unique_name
.
guard
():
with
paddle
.
fluid
.
unique_name
.
guard
():
image
=
fluid
.
layers
.
data
(
image
=
paddle
.
static
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
name
=
'image'
,
shape
=
[
None
]
+
image_shape
,
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
label
=
paddle
.
static
.
data
(
train_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
train_loader
=
paddle
.
io
.
DataLoader
.
from_generator
(
feed_list
=
[
image
,
label
],
feed_list
=
[
image
,
label
],
capacity
=
64
,
capacity
=
64
,
use_double_buffer
=
True
,
use_double_buffer
=
True
,
iterable
=
True
)
iterable
=
True
)
valid_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
valid_loader
=
paddle
.
io
.
DataLoader
.
from_generator
(
feed_list
=
[
image
,
label
],
feed_list
=
[
image
,
label
],
capacity
=
64
,
capacity
=
64
,
use_double_buffer
=
True
,
use_double_buffer
=
True
,
...
@@ -125,32 +126,34 @@ def compress(args):
...
@@ -125,32 +126,34 @@ def compress(args):
# model definition
# model definition
model
=
models
.
__dict__
[
args
.
model
]()
model
=
models
.
__dict__
[
args
.
model
]()
out
=
model
.
net
(
input
=
image
,
class_dim
=
class_dim
)
out
=
model
.
net
(
input
=
image
,
class_dim
=
class_dim
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
out
,
label
=
label
)
cost
=
paddle
.
nn
.
functional
.
loss
.
cross_entropy
(
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
input
=
out
,
label
=
label
)
acc_top1
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
avg_cost
=
paddle
.
mean
(
x
=
cost
)
acc_top5
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
acc_top1
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
train_reader
=
paddle
.
fluid
.
io
.
batch
(
train_reader
=
paddle
.
batch
(
train_reader
,
batch_size
=
args
.
batch_size
,
drop_last
=
True
)
train_reader
,
batch_size
=
args
.
batch_size
,
drop_last
=
True
)
val_reader
=
paddle
.
fluid
.
io
.
batch
(
val_reader
=
paddle
.
batch
(
val_reader
,
batch_size
=
args
.
batch_size
,
drop_last
=
True
)
val_reader
,
batch_size
=
args
.
batch_size
,
drop_last
=
True
)
val_program
=
student_program
.
clone
(
for_test
=
True
)
val_program
=
student_program
.
clone
(
for_test
=
True
)
places
=
fluid
.
cuda_places
()
if
args
.
use_gpu
else
fluid
.
cpu_places
()
places
=
paddle
.
static
.
cuda_places
(
)
if
args
.
use_gpu
else
paddle
.
static
.
cpu_places
()
place
=
places
[
0
]
exe
=
paddle
.
static
.
Executor
(
place
)
train_loader
.
set_sample_list_generator
(
train_reader
,
places
)
train_loader
.
set_sample_list_generator
(
train_reader
,
places
)
valid_loader
.
set_sample_list_generator
(
val_reader
,
place
)
valid_loader
.
set_sample_list_generator
(
val_reader
,
place
)
teacher_model
=
models
.
__dict__
[
args
.
teacher_model
]()
teacher_model
=
models
.
__dict__
[
args
.
teacher_model
]()
# define teacher program
# define teacher program
teacher_program
=
fluid
.
Program
()
teacher_program
=
paddle
.
static
.
Program
()
t_startup
=
fluid
.
Program
()
t_startup
=
paddle
.
static
.
Program
()
with
fluid
.
program_guard
(
teacher_program
,
t_startup
):
with
paddle
.
static
.
program_guard
(
teacher_program
,
t_startup
):
with
fluid
.
unique_name
.
guard
():
with
paddle
.
fluid
.
unique_name
.
guard
():
image
=
fluid
.
layers
.
data
(
image
=
paddle
.
static
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
name
=
'image'
,
shape
=
[
None
]
+
image_shape
,
dtype
=
'float32'
)
predict
=
teacher_model
.
net
(
image
,
class_dim
=
class_dim
)
predict
=
teacher_model
.
net
(
image
,
class_dim
=
class_dim
)
exe
.
run
(
t_startup
)
exe
.
run
(
t_startup
)
...
@@ -171,40 +174,36 @@ def compress(args):
...
@@ -171,40 +174,36 @@ def compress(args):
exist
=
False
exist
=
False
return
exist
return
exist
fluid
.
io
.
load_vars
(
paddle
.
static
.
load
(
teacher_program
,
args
.
teacher_pretrained_model
,
exe
)
exe
,
args
.
teacher_pretrained_model
,
main_program
=
teacher_program
,
predicate
=
if_exist
)
data_name_map
=
{
'image'
:
'image'
}
data_name_map
=
{
'image'
:
'image'
}
merge
(
teacher_program
,
student_program
,
data_name_map
,
place
)
merge
(
teacher_program
,
student_program
,
data_name_map
,
place
)
with
fluid
.
program_guard
(
student_program
,
s_startup
):
with
paddle
.
static
.
program_guard
(
student_program
,
s_startup
):
distill_loss
=
soft_label_loss
(
"teacher_fc_0.tmp_0"
,
"fc_0.tmp_0"
,
distill_loss
=
soft_label_loss
(
"teacher_fc_0.tmp_0"
,
"fc_0.tmp_0"
,
student_program
)
student_program
)
loss
=
avg_cost
+
distill_loss
loss
=
avg_cost
+
distill_loss
lr
,
opt
=
create_optimizer
(
args
)
lr
,
opt
=
create_optimizer
(
args
)
opt
.
minimize
(
loss
)
opt
.
minimize
(
loss
)
exe
.
run
(
s_startup
)
exe
.
run
(
s_startup
)
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
=
paddle
.
static
.
BuildStrategy
()
build_strategy
.
fuse_all_reduce_ops
=
False
build_strategy
.
fuse_all_reduce_ops
=
False
parallel_main
=
fluid
.
CompiledProgram
(
student_program
).
with_data_parallel
(
parallel_main
=
paddle
.
static
.
CompiledProgram
(
student_program
).
with_data_parallel
(
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
)
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
)
for
epoch_id
in
range
(
args
.
num_epochs
):
for
epoch_id
in
range
(
args
.
num_epochs
):
for
step_id
,
data
in
enumerate
(
train_loader
):
for
step_id
,
data
in
enumerate
(
train_loader
):
l
r_np
,
l
oss_1
,
loss_2
,
loss_3
=
exe
.
run
(
loss_1
,
loss_2
,
loss_3
=
exe
.
run
(
parallel_main
,
parallel_main
,
feed
=
data
,
feed
=
data
,
fetch_list
=
[
fetch_list
=
[
loss
.
name
,
avg_cost
.
name
,
distill_loss
.
name
])
lr
.
name
,
loss
.
name
,
avg_cost
.
name
,
distill_loss
.
name
])
if
step_id
%
args
.
log_period
==
0
:
if
step_id
%
args
.
log_period
==
0
:
_logger
.
info
(
_logger
.
info
(
"train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}"
.
"train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}"
.
format
(
epoch_id
,
step_id
,
lr_np
[
0
],
loss_1
[
0
],
loss_2
[
0
],
format
(
epoch_id
,
step_id
,
loss_3
[
0
]))
lr
.
get_lr
(),
loss_1
[
0
],
loss_2
[
0
],
loss_3
[
0
]))
lr
.
step
()
val_acc1s
=
[]
val_acc1s
=
[]
val_acc5s
=
[]
val_acc5s
=
[]
for
step_id
,
data
in
enumerate
(
valid_loader
):
for
step_id
,
data
in
enumerate
(
valid_loader
):
...
@@ -220,7 +219,7 @@ def compress(args):
...
@@ -220,7 +219,7 @@ def compress(args):
format
(
epoch_id
,
step_id
,
val_loss
[
0
],
val_acc1
[
0
],
format
(
epoch_id
,
step_id
,
val_loss
[
0
],
val_acc1
[
0
],
val_acc5
[
0
]))
val_acc5
[
0
]))
if
args
.
save_inference
:
if
args
.
save_inference
:
fluid
.
io
.
save_inference_model
(
paddle
.
static
.
save_inference_model
(
os
.
path
.
join
(
"./saved_models"
,
str
(
epoch_id
)),
[
"image"
],
os
.
path
.
join
(
"./saved_models"
,
str
(
epoch_id
)),
[
"image"
],
[
out
],
exe
,
student_program
)
[
out
],
exe
,
student_program
)
_logger
.
info
(
"epoch {} top1 {:.6f}, top5 {:.6f}"
.
format
(
_logger
.
info
(
"epoch {} top1 {:.6f}, top5 {:.6f}"
.
format
(
...
...
paddleslim/dist/single_distiller.py
浏览文件 @
3a33f1a4
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
import
numpy
as
np
import
numpy
as
np
import
paddle
.fluid
as
fluid
import
paddle
def
merge
(
teacher_program
,
def
merge
(
teacher_program
,
...
@@ -32,7 +32,7 @@ def merge(teacher_program,
...
@@ -32,7 +32,7 @@ def merge(teacher_program,
input interface name, where key of dict is the
input interface name, where key of dict is the
input name of teacher_program, and value is the
input name of teacher_program, and value is the
input name of student_program.
input name of student_program.
place(
fluid.CPUPlace()|fluid.
CUDAPlace(N)): This parameter represents
place(
CPUPlace()|
CUDAPlace(N)): This parameter represents
paddle run on which device.
paddle run on which device.
scope(Scope): This parameter indicates the variable scope used by
scope(Scope): This parameter indicates the variable scope used by
the program. If not specified, the default global scope
the program. If not specified, the default global scope
...
@@ -43,8 +43,8 @@ def merge(teacher_program,
...
@@ -43,8 +43,8 @@ def merge(teacher_program,
Returns:
Returns:
None
None
"""
"""
if
scope
==
None
:
if
scope
==
None
:
scope
=
fluid
.
global_scope
()
scope
=
paddle
.
static
.
global_scope
()
teacher_program
=
teacher_program
.
clone
(
for_test
=
True
)
teacher_program
=
teacher_program
.
clone
(
for_test
=
True
)
for
teacher_var
in
teacher_program
.
list_vars
():
for
teacher_var
in
teacher_program
.
list_vars
():
skip_rename
=
False
skip_rename
=
False
...
@@ -117,22 +117,23 @@ def fsp_loss(teacher_var1_name,
...
@@ -117,22 +117,23 @@ def fsp_loss(teacher_var1_name,
Returns:
Returns:
Variable: fsp distiller loss.
Variable: fsp distiller loss.
"""
"""
if
program
==
None
:
if
program
==
None
:
program
=
fluid
.
default_main_program
()
program
=
paddle
.
static
.
default_main_program
()
teacher_var1
=
program
.
global_block
().
var
(
teacher_var1_name
)
teacher_var1
=
program
.
global_block
().
var
(
teacher_var1_name
)
teacher_var2
=
program
.
global_block
().
var
(
teacher_var2_name
)
teacher_var2
=
program
.
global_block
().
var
(
teacher_var2_name
)
student_var1
=
program
.
global_block
().
var
(
student_var1_name
)
student_var1
=
program
.
global_block
().
var
(
student_var1_name
)
student_var2
=
program
.
global_block
().
var
(
student_var2_name
)
student_var2
=
program
.
global_block
().
var
(
student_var2_name
)
teacher_fsp_matrix
=
fluid
.
layers
.
fsp_matrix
(
teacher_var1
,
teacher_var2
)
teacher_fsp_matrix
=
paddle
.
fluid
.
layers
.
fsp_matrix
(
teacher_var1
,
student_fsp_matrix
=
fluid
.
layers
.
fsp_matrix
(
student_var1
,
student_var2
)
teacher_var2
)
fsp_loss
=
fluid
.
layers
.
reduce_mean
(
student_fsp_matrix
=
paddle
.
fluid
.
layers
.
fsp_matrix
(
student_var1
,
fluid
.
layers
.
square
(
student_fsp_matrix
-
teacher_fsp_matrix
))
student_var2
)
fsp_loss
=
paddle
.
mean
(
paddle
.
nn
.
functional
.
square_error_cost
(
student_fsp_matrix
,
teacher_fsp_matrix
))
return
fsp_loss
return
fsp_loss
def
l2_loss
(
teacher_var_name
,
def
l2_loss
(
teacher_var_name
,
student_var_name
,
program
=
None
):
student_var_name
,
program
=
None
):
"""Combine variables from student model and teacher model by l2-loss.
"""Combine variables from student model and teacher model by l2-loss.
Args:
Args:
...
@@ -144,12 +145,12 @@ def l2_loss(teacher_var_name,
...
@@ -144,12 +145,12 @@ def l2_loss(teacher_var_name,
Returns:
Returns:
Variable: l2 distiller loss.
Variable: l2 distiller loss.
"""
"""
if
program
==
None
:
if
program
==
None
:
program
=
fluid
.
default_main_program
()
program
=
paddle
.
static
.
default_main_program
()
student_var
=
program
.
global_block
().
var
(
student_var_name
)
student_var
=
program
.
global_block
().
var
(
student_var_name
)
teacher_var
=
program
.
global_block
().
var
(
teacher_var_name
)
teacher_var
=
program
.
global_block
().
var
(
teacher_var_name
)
l2_loss
=
fluid
.
layers
.
reduce_
mean
(
l2_loss
=
paddle
.
mean
(
fluid
.
layers
.
square
(
student_var
-
teacher_var
))
paddle
.
nn
.
functional
.
square_error_cost
(
student_var
,
teacher_var
))
return
l2_loss
return
l2_loss
...
@@ -173,15 +174,18 @@ def soft_label_loss(teacher_var_name,
...
@@ -173,15 +174,18 @@ def soft_label_loss(teacher_var_name,
Returns:
Returns:
Variable: l2 distiller loss.
Variable: l2 distiller loss.
"""
"""
if
program
==
None
:
if
program
==
None
:
program
=
fluid
.
default_main_program
()
program
=
paddle
.
static
.
default_main_program
()
student_var
=
program
.
global_block
().
var
(
student_var_name
)
student_var
=
program
.
global_block
().
var
(
student_var_name
)
teacher_var
=
program
.
global_block
().
var
(
teacher_var_name
)
teacher_var
=
program
.
global_block
().
var
(
teacher_var_name
)
student_var
=
fluid
.
layers
.
softmax
(
student_var
/
student_temperature
)
teacher_var
=
fluid
.
layers
.
softmax
(
teacher_var
/
teacher_temperature
)
teacher_var
.
stop_gradient
=
True
teacher_var
.
stop_gradient
=
True
soft_label_loss
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
cross_entropy
(
student_var
=
paddle
.
nn
.
functional
.
softmax
(
student_var
/
student_temperature
)
teacher_var
=
paddle
.
nn
.
functional
.
softmax
(
teacher_var
/
teacher_temperature
)
soft_label_loss
=
paddle
.
mean
(
paddle
.
fluid
.
layers
.
cross_entropy
(
student_var
,
teacher_var
,
soft_label
=
True
))
student_var
,
teacher_var
,
soft_label
=
True
))
return
soft_label_loss
return
soft_label_loss
...
@@ -197,8 +201,8 @@ def loss(loss_func, program=None, **kwargs):
...
@@ -197,8 +201,8 @@ def loss(loss_func, program=None, **kwargs):
Returns:
Returns:
Variable: self defined distiller loss.
Variable: self defined distiller loss.
"""
"""
if
program
==
None
:
if
program
==
None
:
program
=
fluid
.
default_main_program
()
program
=
paddle
.
static
.
default_main_program
()
func_parameters
=
{}
func_parameters
=
{}
for
item
in
kwargs
.
items
():
for
item
in
kwargs
.
items
():
if
isinstance
(
item
[
1
],
str
):
if
isinstance
(
item
[
1
],
str
):
...
...
tests/test_fsp_loss.py
浏览文件 @
3a33f1a4
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
sys
import
sys
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
import
unittest
import
unittest
import
paddle
.fluid
as
fluid
import
paddle
from
paddleslim.dist
import
merge
,
fsp_loss
from
paddleslim.dist
import
merge
,
fsp_loss
from
layers
import
conv_bn_layer
from
layers
import
conv_bn_layer
from
static_case
import
StaticCase
from
static_case
import
StaticCase
...
@@ -22,18 +22,15 @@ from static_case import StaticCase
...
@@ -22,18 +22,15 @@ from static_case import StaticCase
class
TestFSPLoss
(
StaticCase
):
class
TestFSPLoss
(
StaticCase
):
def
test_fsp_loss
(
self
):
def
test_fsp_loss
(
self
):
student_main
=
fluid
.
Program
()
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
student_startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
student_main
,
student_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
student_predict
=
conv1
+
conv2
student_predict
=
conv1
+
conv2
teacher_main
=
fluid
.
Program
()
teacher_main
=
paddle
.
static
.
Program
()
teacher_startup
=
fluid
.
Program
()
teacher_startup
=
paddle
.
static
.
Program
()
with
fluid
.
program_guard
(
teacher_main
,
teacher_startup
):
with
paddle
.
static
.
program_guard
(
teacher_main
,
teacher_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
sum1
=
conv1
+
conv2
sum1
=
conv1
+
conv2
...
@@ -43,20 +40,19 @@ class TestFSPLoss(StaticCase):
...
@@ -43,20 +40,19 @@ class TestFSPLoss(StaticCase):
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
place
=
fluid
.
CPUPlace
()
place
=
paddle
.
CPUPlace
()
data_name_map
=
{
'image'
:
'image'
}
data_name_map
=
{
'image'
:
'image'
}
merge
(
teacher_main
,
student_main
,
data_name_map
,
place
)
merge
(
teacher_main
,
paddle
.
static
.
default_main_program
(),
data_name_map
,
place
)
merged_ops
=
[]
merged_ops
=
[]
for
block
in
student_main
.
blocks
:
for
block
in
paddle
.
static
.
default_main_program
()
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
merged_ops
.
append
(
op
.
type
)
merged_ops
.
append
(
op
.
type
)
with
fluid
.
program_guard
(
student_main
):
distill_loss
=
fsp_loss
(
distill_loss
=
fsp_loss
(
'teacher_conv5_bn_output.tmp_2'
,
'teacher_conv5_bn_output.tmp_2'
,
'teacher_conv6_bn_output.tmp_2'
,
'teacher_conv6_bn_output.tmp_2'
,
'conv1_bn_output.tmp_2'
,
'conv2_bn_output.tmp_2'
)
'conv1_bn_output.tmp_2'
,
'conv2_bn_output.tmp_2'
,
student_main
)
loss_ops
=
[]
loss_ops
=
[]
for
block
in
student_main
.
blocks
:
for
block
in
paddle
.
static
.
default_main_program
()
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
loss_ops
.
append
(
op
.
type
)
loss_ops
.
append
(
op
.
type
)
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
...
...
tests/test_l2_loss.py
浏览文件 @
3a33f1a4
...
@@ -15,7 +15,6 @@ import sys
...
@@ -15,7 +15,6 @@ import sys
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
import
unittest
import
unittest
import
paddle
import
paddle
import
paddle.fluid
as
fluid
from
static_case
import
StaticCase
from
static_case
import
StaticCase
from
paddleslim.dist
import
merge
,
l2_loss
from
paddleslim.dist
import
merge
,
l2_loss
from
layers
import
conv_bn_layer
from
layers
import
conv_bn_layer
...
@@ -23,18 +22,15 @@ from layers import conv_bn_layer
...
@@ -23,18 +22,15 @@ from layers import conv_bn_layer
class
TestL2Loss
(
StaticCase
):
class
TestL2Loss
(
StaticCase
):
def
test_l2_loss
(
self
):
def
test_l2_loss
(
self
):
student_main
=
fluid
.
Program
()
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
student_startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
student_main
,
student_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
student_predict
=
conv1
+
conv2
student_predict
=
conv1
+
conv2
teacher_main
=
fluid
.
Program
()
teacher_main
=
paddle
.
static
.
Program
()
teacher_startup
=
fluid
.
Program
()
teacher_startup
=
paddle
.
static
.
Program
()
with
fluid
.
program_guard
(
teacher_main
,
teacher_startup
):
with
paddle
.
static
.
program_guard
(
teacher_main
,
teacher_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
sum1
=
conv1
+
conv2
sum1
=
conv1
+
conv2
...
@@ -44,18 +40,18 @@ class TestL2Loss(StaticCase):
...
@@ -44,18 +40,18 @@ class TestL2Loss(StaticCase):
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
place
=
fluid
.
CPUPlace
()
place
=
paddle
.
CPUPlace
()
data_name_map
=
{
'image'
:
'image'
}
data_name_map
=
{
'image'
:
'image'
}
merge
(
teacher_main
,
student_main
,
data_name_map
,
place
)
merge
(
teacher_main
,
paddle
.
static
.
default_main_program
(),
data_name_map
,
place
)
merged_ops
=
[]
merged_ops
=
[]
for
block
in
student_main
.
blocks
:
for
block
in
paddle
.
static
.
default_main_program
()
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
merged_ops
.
append
(
op
.
type
)
merged_ops
.
append
(
op
.
type
)
with
fluid
.
program_guard
(
student_main
):
distill_loss
=
l2_loss
(
'teacher_conv6_bn_output.tmp_2'
,
distill_loss
=
l2_loss
(
'teacher_conv6_bn_output.tmp_2'
,
'conv2_bn_output.tmp_2'
,
student_main
)
'conv2_bn_output.tmp_2'
)
loss_ops
=
[]
loss_ops
=
[]
for
block
in
student_main
.
blocks
:
for
block
in
paddle
.
static
.
default_main_program
()
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
loss_ops
.
append
(
op
.
type
)
loss_ops
.
append
(
op
.
type
)
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
...
...
tests/test_loss.py
浏览文件 @
3a33f1a4
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
sys
import
sys
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
import
unittest
import
unittest
import
paddle
.fluid
as
fluid
import
paddle
from
paddleslim.dist
import
merge
,
loss
from
paddleslim.dist
import
merge
,
loss
from
layers
import
conv_bn_layer
from
layers
import
conv_bn_layer
from
static_case
import
StaticCase
from
static_case
import
StaticCase
...
@@ -22,18 +22,15 @@ from static_case import StaticCase
...
@@ -22,18 +22,15 @@ from static_case import StaticCase
class
TestLoss
(
StaticCase
):
class
TestLoss
(
StaticCase
):
def
test_loss
(
self
):
def
test_loss
(
self
):
student_main
=
fluid
.
Program
()
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
student_startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
student_main
,
student_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
student_predict
=
conv1
+
conv2
student_predict
=
conv1
+
conv2
teacher_main
=
fluid
.
Program
()
teacher_main
=
paddle
.
static
.
Program
()
teacher_startup
=
fluid
.
Program
()
teacher_startup
=
paddle
.
static
.
Program
()
with
fluid
.
program_guard
(
teacher_main
,
teacher_startup
):
with
paddle
.
static
.
program_guard
(
teacher_main
,
teacher_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
sum1
=
conv1
+
conv2
sum1
=
conv1
+
conv2
...
@@ -43,29 +40,26 @@ class TestLoss(StaticCase):
...
@@ -43,29 +40,26 @@ class TestLoss(StaticCase):
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
place
=
fluid
.
CPUPlace
()
place
=
paddle
.
CPUPlace
()
data_name_map
=
{
'image'
:
'image'
}
data_name_map
=
{
'image'
:
'image'
}
merge
(
teacher_main
,
student_main
,
data_name_map
,
place
)
merge
(
teacher_main
,
paddle
.
static
.
default_main_program
(),
data_name_map
,
place
)
merged_ops
=
[]
merged_ops
=
[]
for
block
in
student_main
.
blocks
:
for
block
in
paddle
.
static
.
default_main_program
()
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
merged_ops
.
append
(
op
.
type
)
merged_ops
.
append
(
op
.
type
)
def
adaptation_loss
(
t_var
,
s_var
):
def
adaptation_loss
(
t_var
,
s_var
):
teacher_channel
=
t_var
.
shape
[
1
]
hint_loss
=
paddle
.
mean
(
s_hint
=
fluid
.
layers
.
conv2d
(
s_var
,
teacher_channel
,
1
)
paddle
.
nn
.
functional
.
square_error_cost
(
s_var
,
t_var
))
hint_loss
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
square
(
s_hint
-
t_var
))
return
hint_loss
return
hint_loss
with
fluid
.
program_guard
(
student_main
):
distill_loss
=
loss
(
distill_loss
=
loss
(
adaptation_loss
,
adaptation_loss
,
student_main
,
t_var
=
'teacher_conv6_bn_output.tmp_2'
,
t_var
=
'teacher_conv6_bn_output.tmp_2'
,
s_var
=
'conv2_bn_output.tmp_2'
)
s_var
=
'conv2_bn_output.tmp_2'
)
loss_ops
=
[]
loss_ops
=
[]
for
block
in
student_main
.
blocks
:
for
block
in
paddle
.
static
.
default_main_program
()
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
loss_ops
.
append
(
op
.
type
)
loss_ops
.
append
(
op
.
type
)
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
...
...
tests/test_merge.py
浏览文件 @
3a33f1a4
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
sys
import
sys
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
import
unittest
import
unittest
import
paddle
.fluid
as
fluid
import
paddle
from
paddleslim.dist
import
merge
from
paddleslim.dist
import
merge
from
layers
import
conv_bn_layer
from
layers
import
conv_bn_layer
from
static_case
import
StaticCase
from
static_case
import
StaticCase
...
@@ -22,10 +22,10 @@ from static_case import StaticCase
...
@@ -22,10 +22,10 @@ from static_case import StaticCase
class
TestMerge
(
StaticCase
):
class
TestMerge
(
StaticCase
):
def
test_merge
(
self
):
def
test_merge
(
self
):
student_main
=
fluid
.
Program
()
student_main
=
paddle
.
static
.
Program
()
student_startup
=
fluid
.
Program
()
student_startup
=
paddle
.
static
.
Program
()
with
fluid
.
program_guard
(
student_main
,
student_startup
):
with
paddle
.
static
.
program_guard
(
student_main
,
student_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
student_predict
=
conv1
+
conv2
student_predict
=
conv1
+
conv2
...
@@ -34,10 +34,10 @@ class TestMerge(StaticCase):
...
@@ -34,10 +34,10 @@ class TestMerge(StaticCase):
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
student_ops
.
append
(
op
)
student_ops
.
append
(
op
)
teacher_main
=
fluid
.
Program
()
teacher_main
=
paddle
.
static
.
Program
()
teacher_startup
=
fluid
.
Program
()
teacher_startup
=
paddle
.
static
.
Program
()
with
fluid
.
program_guard
(
teacher_main
,
teacher_startup
):
with
paddle
.
static
.
program_guard
(
teacher_main
,
teacher_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
sum1
=
conv1
+
conv2
sum1
=
conv1
+
conv2
...
@@ -51,7 +51,7 @@ class TestMerge(StaticCase):
...
@@ -51,7 +51,7 @@ class TestMerge(StaticCase):
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
teacher_ops
.
append
(
op
)
teacher_ops
.
append
(
op
)
place
=
fluid
.
CPUPlace
()
place
=
paddle
.
CPUPlace
()
data_name_map
=
{
'image'
:
'image'
}
data_name_map
=
{
'image'
:
'image'
}
merge
(
teacher_main
,
student_main
,
data_name_map
,
place
)
merge
(
teacher_main
,
student_main
,
data_name_map
,
place
)
merged_ops
=
[]
merged_ops
=
[]
...
...
tests/test_soft_label_loss.py
浏览文件 @
3a33f1a4
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
sys
import
sys
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
import
unittest
import
unittest
import
paddle
.fluid
as
fluid
import
paddle
from
paddleslim.dist
import
merge
,
soft_label_loss
from
paddleslim.dist
import
merge
,
soft_label_loss
from
layers
import
conv_bn_layer
from
layers
import
conv_bn_layer
from
static_case
import
StaticCase
from
static_case
import
StaticCase
...
@@ -22,18 +22,15 @@ from static_case import StaticCase
...
@@ -22,18 +22,15 @@ from static_case import StaticCase
class
TestSoftLabelLoss
(
StaticCase
):
class
TestSoftLabelLoss
(
StaticCase
):
def
test_soft_label_loss
(
self
):
def
test_soft_label_loss
(
self
):
student_main
=
fluid
.
Program
()
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
student_startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
student_main
,
student_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
student_predict
=
conv1
+
conv2
student_predict
=
conv1
+
conv2
teacher_main
=
fluid
.
Program
()
teacher_main
=
paddle
.
static
.
Program
()
teacher_startup
=
fluid
.
Program
()
teacher_startup
=
paddle
.
static
.
Program
()
with
fluid
.
program_guard
(
teacher_main
,
teacher_startup
):
with
paddle
.
static
.
program_guard
(
teacher_main
,
teacher_startup
):
input
=
fluid
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
input
=
paddle
.
static
.
data
(
name
=
"image"
,
shape
=
[
None
,
3
,
224
,
224
])
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv1
=
conv_bn_layer
(
input
,
8
,
3
,
"conv1"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
conv2
=
conv_bn_layer
(
conv1
,
8
,
3
,
"conv2"
)
sum1
=
conv1
+
conv2
sum1
=
conv1
+
conv2
...
@@ -43,19 +40,18 @@ class TestSoftLabelLoss(StaticCase):
...
@@ -43,19 +40,18 @@ class TestSoftLabelLoss(StaticCase):
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
conv5
=
conv_bn_layer
(
sum2
,
8
,
3
,
"conv5"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
teacher_predict
=
conv_bn_layer
(
conv5
,
8
,
3
,
"conv6"
)
place
=
fluid
.
CPUPlace
()
place
=
paddle
.
CPUPlace
()
data_name_map
=
{
'image'
:
'image'
}
data_name_map
=
{
'image'
:
'image'
}
merge
(
teacher_main
,
student_main
,
data_name_map
,
place
)
merge
(
teacher_main
,
paddle
.
static
.
default_main_program
(),
data_name_map
,
place
)
merged_ops
=
[]
merged_ops
=
[]
for
block
in
student_main
.
blocks
:
for
block
in
paddle
.
static
.
default_main_program
()
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
merged_ops
.
append
(
op
.
type
)
merged_ops
.
append
(
op
.
type
)
with
fluid
.
program_guard
(
student_main
):
distill_loss
=
soft_label_loss
(
'teacher_conv6_bn_output.tmp_2'
,
distill_loss
=
soft_label_loss
(
'teacher_conv6_bn_output.tmp_2'
,
'conv2_bn_output.tmp_2'
,
'conv2_bn_output.tmp_2'
)
student_main
)
loss_ops
=
[]
loss_ops
=
[]
for
block
in
student_main
.
blocks
:
for
block
in
paddle
.
static
.
default_main_program
()
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
loss_ops
.
append
(
op
.
type
)
loss_ops
.
append
(
op
.
type
)
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
self
.
assertTrue
(
set
(
merged_ops
).
difference
(
set
(
loss_ops
))
==
set
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录