Commit 70c6b708
Authored May 19, 2020 by wanghaoshuang

Update bert based DARTS demo

Parent: d3e51bfc
Showing 8 changed files with 198 additions and 100 deletions (+198 / -100):

- paddleslim/common/meter.py (+5 / -1)
- paddleslim/nas/darts/architect.py (+3 / -0)
- paddleslim/nas/darts/search_space/conv_bert/cls.py (+89 / -23)
- paddleslim/nas/darts/search_space/conv_bert/model/bert.py (+4 / -6)
- paddleslim/nas/darts/search_space/conv_bert/model/cls.py (+4 / -4)
- paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py (+40 / -22)
- paddleslim/nas/darts/train_search.py (+47 / -38)
- paddleslim/teachers/bert/model/transformer_encoder.py (+6 / -6)
paddleslim/common/meter.py

```diff
@@ -16,8 +16,9 @@ __all__ = ['AvgrageMeter']


 class AvgrageMeter(object):
-    def __init__(self):
+    def __init__(self, format="{}"):
         self.reset()
+        self._format = format

     def reset(self):
         self.avg = 0
@@ -28,3 +29,6 @@ class AvgrageMeter(object):
         self.sum += val * n
         self.cnt += n
         self.avg = self.sum / self.cnt
+
+    def __repr__(self):
+        return self._format.format(self.avg)
```
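A quick usage sketch of the new `format` argument, assuming the meter's existing `update(val, n)` accumulation shown in the context lines; the `"{:.4f}"` format string is an illustrative choice, not part of the commit:

```python
from paddleslim.common import AvgrageMeter

meter = AvgrageMeter(format="{:.4f}")  # hypothetical format string
meter.update(0.5, 4)    # running sum = 2.0, count = 4
meter.update(0.25, 4)   # running sum = 3.0, count = 8
print(meter)            # __repr__ now renders meter.avg (0.375) via the format
```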
paddleslim/nas/darts/architect.py

```diff
@@ -56,7 +56,10 @@ class Architect(object):
         else:
             loss = self._backward_step(valid_data)
             self.optimizer.minimize(loss)
+        # print("alphas gradient: {}".format(self.model.arch_parameters()[0].gradient()))
         self.optimizer.clear_gradients()
+
+        return self.model.arch_parameters()[0].gradient()

     def _backward_step(self, valid_data):
         loss = self.model.loss(valid_data)
```
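`step()` now returns the gradient of the first architecture-parameter tensor; the search loop in train_search.py below uses it to diagnose NaN losses. A minimal runnable sketch of the new contract, with a stand-in class (the real `Architect` minimizes the validation loss before returning):

```python
import numpy as np

class _FakeArchitect(object):
    """Stand-in for paddleslim.nas.darts.Architect; shape and values invented."""
    def step(self, train_data, valid_data):
        # The real step() minimizes the validation loss, clears gradients,
        # and now returns self.model.arch_parameters()[0].gradient().
        return np.zeros((26, 10))

architect = _FakeArchitect()
alphas_grad = architect.step(None, None)
print("alphas gradient: {}".format(alphas_grad))
```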
paddleslim/nas/darts/search_space/conv_bert/cls.py

```diff
@@ -50,7 +50,9 @@ class AdaBERTClassifier(Layer):
                  beta=4,
                  conv_type="conv_bn",
                  search_layer=True,
-                 teacher_model=None):
+                 teacher_model=None,
+                 alphas=None,
+                 k=None):
         super(AdaBERTClassifier, self).__init__()
         self._n_layer = n_layer
         self._num_labels = num_labels
@@ -60,6 +62,8 @@ class AdaBERTClassifier(Layer):
         self._beta = beta
         self._conv_type = conv_type
         self._search_layer = search_layer
+        self._alphas = alphas
+        self._k = k
         print(
             "----------------------load teacher model and test----------------------------------------"
         )
@@ -86,20 +90,36 @@ class AdaBERTClassifier(Layer):
                 bias_attr=fluid.ParamAttr(
                     name="s_cls_out_%d_b" % i,
                     initializer=fluid.initializer.Constant(0.)))
-            fc = self.add_sublayer("cls_fc_%d" % i, fc)
+            fc = self.add_sublayer("s_cls_fc_%d" % i, fc)
             self.cls_fc.append(fc)

-    def forward(self, data_ids):
+    def forward(self, data_ids, alphas=None, k=None):
         src_ids = data_ids[0]
         position_ids = data_ids[1]
         sentence_ids = data_ids[2]
-        return self.student(src_ids, position_ids, sentence_ids)
+        return self.student(
+            src_ids,
+            position_ids,
+            sentence_ids,
+            alphas=self._alphas,
+            k=self._k)

     def arch_parameters(self):
         return self.student.arch_parameters()

+    def model_parameters(self):
+        model_parameters = [
+            p for p in self.student.parameters()
+            if p.name not in [a.name for a in self.arch_parameters()]
+        ]
+        return model_parameters
+
     def genotype(self):
-        return self.arch_parameters()
+        alphas = self.arch_parameters()[0].numpy()
+        alphas = [np.argmax(edge) for edge in alphas]
+        k = np.argmax(self.arch_parameters()[1].numpy())
+        return "layers: {}; edges: {} ".format(k, alphas)

     def new(self):
         model_new = AdaBERTClassifier(
@@ -108,8 +128,7 @@ class AdaBERTClassifier(Layer):
         )
         return model_new

-    def loss(self, data_ids):
-        T = 1.0
+    def valid(self, data_ids):
         src_ids = data_ids[0]
         position_ids = data_ids[1]
         sentence_ids = data_ids[2]
@@ -117,16 +136,57 @@ class AdaBERTClassifier(Layer):
         labels = data_ids[4]

         flops = []
         model_size = []
+        alphas = self.arch_parameters()[0].numpy(
+        ) if self._alphas is None else self._alphas
+        k = self.arch_parameters()[1].numpy() if self._k is None else self._k
+        print(alphas.shape)
+        print(k.shape)
         enc_outputs, next_sent_feats, k_i = self.student(
             src_ids,
             position_ids,
             sentence_ids,
             flops=flops,
-            model_size=model_size)
+            model_size=model_size,
+            alphas=alphas,
+            k=k)
+        logits = self.cls_fc[-1](next_sent_feats[-1])
+        probs = fluid.layers.softmax(logits)
+        accuracy = fluid.layers.accuracy(input=probs, label=labels)
+        model_size = np.sum(model_size)
+        flops = np.sum(flops)
+        ret = {
+            "accuracy": accuracy.numpy(),
+            "model_size(MB)": model_size / 1e6,
+            "FLOPs(M)": flops / 1e6
+        }
+        return ret
+
+    def loss(self, data_ids):
+        T = 1.0
+        src_ids = data_ids[0]
+        position_ids = data_ids[1]
+        sentence_ids = data_ids[2]
+        input_mask = data_ids[3]
+        labels = data_ids[4]
+        flops = []
+        model_size = []
         self.teacher.eval()
         total_loss, t_logits, t_losses, accuracys, num_seqs = self.teacher(
             data_ids)
+        self.teacher.train()
+        enc_outputs, next_sent_feats, k_i = self.student(
+            src_ids,
+            position_ids,
+            sentence_ids,
+            flops=flops,
+            model_size=model_size,
+            alphas=self._alphas,
+            k=self._k)

         # define kd loss
         kd_losses = []
@@ -140,21 +200,16 @@ class AdaBERTClassifier(Layer):
         kd_weights = np.exp(kd_weights - np.max(kd_weights))
         kd_weights = kd_weights / kd_weights.sum(axis=0)

-        s_probs = None
         for i in range(len(next_sent_feats)):
             j = int(np.ceil(i * (float(len(t_logits)) / len(next_sent_feats))))
             t_logit = t_logits[j]
             s_sent_feat = next_sent_feats[i]
             fc = self.cls_fc[i]
-            s_sent_feat = fluid.layers.dropout(
-                x=s_sent_feat,
-                dropout_prob=0.1,
-                dropout_implementation="upscale_in_train")
             s_logits = fc(s_sent_feat)
-            t_logit.stop_gradient = True
             t_probs = fluid.layers.softmax(t_logit)
             s_probs = fluid.layers.softmax(s_logits)
+            t_probs.stop_gradient = True
             kd_loss = t_probs * fluid.layers.log(s_probs / T)
             kd_loss = fluid.layers.reduce_sum(kd_loss, dim=1)
             kd_loss = kd_loss * kd_weights[i]
@@ -167,17 +222,28 @@ class AdaBERTClassifier(Layer):
         ce_loss = fluid.layers.cross_entropy(s_probs, labels)
         ce_loss = fluid.layers.reduce_mean(ce_loss) * k_i

+        len_model_size = len(model_size)
         # define e loss
-        model_size = fluid.layers.sum(model_size)
-        # print("model_size: {}".format(model_size.numpy()/1e6))
+        if self._alphas is not None:
+            flops = np.sum(flops)
+            model_size = np.sum(model_size)
+        else:
+            flops = fluid.layers.sum(flops)
+            model_size = fluid.layers.sum(model_size)
         model_size = model_size / self.student.max_model_size()
-        flops = fluid.layers.sum(flops) / self.student.max_flops()
-        e_loss = (len(next_sent_feats) * k_i / self._n_layer) * (
-            flops + model_size)
+        flops = flops / self.student.max_flops()
+        e_loss = (flops + model_size) * (len(next_sent_feats) * k_i /
+                                         self._n_layer)
+        print("len(next_sent_feats): {}; k_i: {}; flops: {}; model_size: {}; len: {}".
+              format(
+                  len(next_sent_feats), k_i, flops.numpy(), model_size.numpy(),
+                  len_model_size))
         # define total loss
         loss = (1 - self._gamma
                 ) * ce_loss - self._gamma * kd_loss + self._beta * e_loss
+        # print("ce_loss: {}; kd_loss: {}; e_loss: {}".format((
+        #     1 - gamma) * ce_loss.numpy(), -gamma * kd_loss.numpy(), beta *
+        #     e_loss.numpy()))
         return loss, ce_loss, kd_loss, e_loss
+        # loss = ce_loss + self._beta * e_loss
+        # return loss, ce_loss, ce_loss, e_loss
```
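For orientation, a NumPy-only sketch of the distillation weighting used in `loss()` above: teacher-side scores become per-layer weights through a numerically stable softmax, and each student layer's soft cross-entropy term against its matched teacher layer is scaled by that weight. All values below are invented; only the arithmetic mirrors the diff.

```python
import numpy as np

T = 1.0                                              # distillation temperature, as in the diff
kd_weights = np.array([0.9, 0.5, 0.2])               # per-teacher-layer scores (invented)
kd_weights = np.exp(kd_weights - np.max(kd_weights))  # stable softmax, as above
kd_weights = kd_weights / kd_weights.sum(axis=0)

t_probs = np.array([0.7, 0.2, 0.1])  # teacher distribution (gradient stopped in the diff)
s_probs = np.array([0.6, 0.3, 0.1])  # student distribution for one layer
# Soft cross-entropy term for layer 0, scaled by its layer weight:
kd_loss_0 = (t_probs * np.log(s_probs / T)).sum() * kd_weights[0]
print(kd_loss_0)
```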
paddleslim/nas/darts/search_space/conv_bert/model/bert.py

```diff
@@ -110,10 +110,9 @@ class BertModelLayer(Layer):
                 position_ids,
                 sentence_ids,
                 flops=[],
-                model_size=[]):
-        """
-        forward
-        """
+                model_size=[],
+                alphas=None,
+                k=None):
         src_emb = self._src_emb(src_ids)
         pos_emb = self._pos_emb(position_ids)
         sent_emb = self._sent_emb(sentence_ids)
@@ -122,9 +121,8 @@ class BertModelLayer(Layer):
         emb_out = emb_out + sent_emb
         emb_out = self._emb_fac(emb_out)

-        enc_outputs, k_i = self._encoder(
-            emb_out, flops=flops, model_size=model_size)
+        enc_outputs, k_i = self._encoder(
+            emb_out, flops=flops, model_size=model_size, alphas=alphas, k=k)

         if not self.return_pooled_out:
             return enc_outputs
```
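The net effect of this plumbing is that a caller can pin the architecture instead of sampling it: when `alphas`/`k` are passed down (e.g. NumPy arrays taken from `arch_parameters()`), the encoder skips its Gumbel-softmax sampling. A toy stand-in for that dispatch, with made-up names and shapes:

```python
import numpy as np

def encode(emb_out, alphas=None, k=None):
    """Toy stand-in for EncoderLayer.forward's new alphas/k handling."""
    if alphas is None:
        alphas = np.random.gumbel(size=(26, 10))  # stands in for gumbel_softmax(self.alphas)
    if k is None:
        k = np.random.gumbel(size=(12,))          # stands in for gumbel_softmax(self.k)
    return alphas, k

fixed_alphas = np.ones((26, 10)) / 10.0  # e.g. arch_parameters()[0].numpy()
fixed_k = np.eye(12)[3]                  # e.g. a one-hot layer choice
alphas, k = encode(None, alphas=fixed_alphas, k=fixed_k)  # sampling is skipped
```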
paddleslim/nas/darts/search_space/conv_bert/model/cls.py

```diff
@@ -57,12 +57,12 @@ class ClsModelLayer(Layer):
             fc = Linear(
                 input_dim=self.config["hidden_size"],
                 output_dim=num_labels,
-                param_attr=fluid.ParamAttr(
+                param_attr=fluid.paramattr(
                     name="cls_out_%d_w" % i,
-                    initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
+                    initializer=fluid.initializer.truncatednormal(scale=0.02)),
-                bias_attr=fluid.ParamAttr(
+                bias_attr=fluid.paramattr(
                     name="cls_out_%d_b" % i,
-                    initializer=fluid.initializer.Constant(0.)))
+                    initializer=fluid.initializer.constant(0.)))
             fc = self.add_sublayer("cls_fc_%d" % i, fc)
             self.cls_fc.append(fc)
```
paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py

```diff
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import numpy as np
+from collections import Iterable

 import paddle
 import paddle.fluid as fluid
@@ -25,13 +26,13 @@ from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, Conv2D, Ba
 from paddle.fluid.initializer import NormalInitializer

 GConv_PRIMITIVES = [
-    'std_gconv_3', 'std_gconv_5', 'std_gconv_7', 'dil_gconv_3', 'dil_gconv_5',
-    'dil_gconv_7', 'avg_pool_3', 'max_pool_3', 'none', 'skip_connect'
+    'none', 'std_gconv_3', 'std_gconv_5', 'std_gconv_7', 'dil_gconv_3',
+    'dil_gconv_5', 'dil_gconv_7', 'avg_pool_3', 'max_pool_3', 'skip_connect'
 ]

 ConvBN_PRIMITIVES = [
-    'std_conv_bn_3', 'std_conv_bn_5', 'std_conv_bn_7', 'dil_conv_bn_3',
-    'dil_conv_bn_5', 'dil_conv_bn_7', 'avg_pool_3', 'max_pool_3', 'none',
-    'skip_connect'
+    'none', 'std_conv_bn_3', 'std_conv_bn_5', 'std_conv_bn_7', 'dil_conv_bn_3',
+    'dil_conv_bn_5', 'dil_conv_bn_7', 'avg_pool_3', 'max_pool_3',
+    'skip_connect'
 ]
@@ -117,7 +118,11 @@ class MixedOp(fluid.dygraph.Layer):
     def forward(self, x, weights, flops=[], model_size=[]):
         for i in range(len(self._ops)):
-            if weights[i].numpy() != 0:
+            if isinstance(weights, Iterable):
+                weights_i = weights[i]
+            else:
+                weights_i = weights[i].numpy()
+            if weights_i != 0:
                 flops.append(FLOPs.values()[i] * weights[i])
                 model_size.append(ModelSize.values()[i] * weights[i])
                 return self._ops[i](x) * weights[i]
@@ -166,6 +171,7 @@ class ConvBNRelu(fluid.dygraph.Layer):
                  use_cudnn=True,
                  name=None):
         super(ConvBNRelu, self).__init__()
+        self._name = name
         conv_std = (2.0 /
                     (filter_size[0] * filter_size[1] * out_c * in_c))**0.5
         conv_param = fluid.ParamAttr(
@@ -187,7 +193,7 @@ class ConvBNRelu(fluid.dygraph.Layer):
     def forward(self, inputs):
         conv = self.conv(inputs)
         bn = self.bn(conv)
-        return bn
+        return conv


 class GateConv(fluid.dygraph.Layer):
@@ -261,24 +267,30 @@ class Cell(fluid.dygraph.Layer):
         states = [s0, s1]
         offset = 0
         for i in range(self._steps):
-            s = fluid.layers.sums([
-                self._ops[offset + j](h,
-                                      weights[offset + j],
-                                      flops=flops,
-                                      model_size=model_size)
-                for j, h in enumerate(states)
-            ])
+            edges = []
+            for j, h in enumerate(states):
+                edge = self._ops[offset + j](h,
+                                             weights[offset + j],
+                                             flops=flops,
+                                             model_size=model_size)
+                edges.append(edge)
+            s = edges[0]
+            for n in range(1, len(edges)):
+                s = s + edges[n]
+            # s = fluid.layers.sums(edges)
             offset += len(states)
             states.append(s)
-        out = fluid.layers.sum(states[-self._steps:])
+        states = states[-self._steps:]
+        out = states[0]
+        for n in range(1, len(states)):
+            out = out + states[n]
+        # out = fluid.layers.sums(states[-self._steps:])
         return out


 class EncoderLayer(Layer):
     """
     encoder
     """

     def __init__(self,
                  n_layer,
                  hidden_size=768,
@@ -342,15 +354,17 @@ class EncoderLayer(Layer):
             default_initializer=NormalInitializer(
                 loc=0.0, scale=1e-3))

-    def forward(self, enc_input, flops=[], model_size=[]):
+    def forward(self, enc_input, flops=[], model_size=[], alphas=None, k=None):
         tmp = fluid.layers.reshape(
             enc_input, [-1, 1, enc_input.shape[1],
                         self._hidden_size])  #(bs, 1, seq_len, hidden_size)

         tmp = self.conv0(tmp)  # (bs, hidden_size, seq_len, 1)

-        alphas = gumbel_softmax(self.alphas)
-        k = fluid.layers.reshape(gumbel_softmax(self.k), [-1])
+        if alphas is None:
+            alphas = gumbel_softmax(self.alphas)
+        if k is None:
+            k = fluid.layers.reshape(gumbel_softmax(self.k), [-1])

         outputs = []
         s0 = s1 = tmp
@@ -364,7 +378,11 @@ class EncoderLayer(Layer):
                 enc_output, [-1, enc_output.shape[1],
                              self._hidden_size])  # (bs, seq_len, hidden_size)
             outputs.append(enc_output)
-            if self._search_layer and k[i].numpy() != 0:
+            if isinstance(k, Iterable):
+                k_i = k[i]
+            else:
+                k_i = k[i].numpy()
+            if k_i != 0:
                 outputs[-1] = outputs[-1] * k[i]
                 return outputs, k[i]
         return outputs, 1.0
```
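A toy NumPy sketch of the `MixedOp` gating above: each candidate op's output is scaled by its architecture weight, and the first nonzero-weight op short-circuits the loop (which is what makes fixed one-hot `alphas` cheap at evaluation time). The ops and weights here are invented.

```python
import numpy as np

ops = [lambda x: x * 0.0,   # 'none'
       lambda x: x + 1.0,   # stand-in for a conv primitive
       lambda x: x * 2.0]   # stand-in for another primitive
weights = np.array([0.0, 1.0, 0.0])  # e.g. a hard Gumbel-softmax sample

x = np.ones(4)
out = None
for i in range(len(ops)):
    if weights[i] != 0:               # mirrors `if weights_i != 0` in the diff
        out = ops[i](x) * weights[i]  # only the selected op runs
        break                         # the diff returns from inside the loop
print(out)
```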
paddleslim/nas/darts/train_search.py

```diff
@@ -18,10 +18,12 @@ from __future__ import print_function

 __all__ = ['DARTSearch']

+import math
 import logging
 from itertools import izip
 import numpy as np
 import paddle.fluid as fluid
+from paddle.fluid.framework import Variable
 from paddle.fluid.dygraph.base import to_variable
 from ...common import AvgrageMeter, get_logger
 from .architect import Architect
@@ -41,6 +43,7 @@ class DARTSearch(object):
                  model,
                  train_reader,
                  valid_reader,
+                 test_reader=None,
                  learning_rate=0.025,
                  batchsize=64,
                  num_imgs=50000,
@@ -54,6 +57,7 @@ class DARTSearch(object):
         self.model = model
         self.train_reader = train_reader
         self.valid_reader = valid_reader
+        self.test_reader = test_reader
         self.learning_rate = learning_rate
         self.batchsize = batchsize
         self.num_imgs = num_imgs
@@ -82,10 +86,13 @@ class DARTSearch(object):
         step_id = 0
         for train_data, valid_data in izip(train_loader(), valid_loader()):
             if epoch >= self.epochs_no_archopt:
-                architect.step(train_data, valid_data)
+                alphas_grad = architect.step(train_data, valid_data)
             loss, ce_loss, kd_loss, e_loss = self.model.loss(train_data)
+            if math.isnan(e_loss.numpy()):
+                print("alphas_grad: {}".format(alphas_grad))
+                print("alphas: {}".format(self.model.arch_parameters()[0]
+                                          .numpy()))
             if self.use_data_parallel:
                 loss = self.model.scale_loss(loss)
                 loss.backward()
@@ -93,67 +100,55 @@ class DARTSearch(object):
             else:
                 loss.backward()
-            grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5)
-            optimizer.minimize(loss, grad_clip)
+            # grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5)
+            # optimizer.minimize(loss, grad_clip)
+            optimizer.minimize(loss)
             self.model.clear_gradients()

             batch_size = train_data[0].shape[0]
             objs.update(loss.numpy(), batch_size)
+            e_loss = e_loss.numpy() if isinstance(e_loss, Variable) else e_loss
             ce_losses.update(ce_loss.numpy(), batch_size)
             kd_losses.update(kd_loss.numpy(), batch_size)
-            e_losses.update(e_loss.numpy(), batch_size)
+            e_losses.update(e_loss, batch_size)
             if step_id % self.log_freq == 0:
+                #logger.info("Train Epoch {}, Step {}, loss {:.6f}; ce: {:.6f}; kd: {:.6f}; e: {:.6f}".format(
+                #    epoch, step_id, objs.avg[0], ce_losses.avg[0], kd_losses.avg[0], e_losses.avg[0]))
                 logger.info(
                     "Train Epoch {}, Step {}, loss {}; ce: {}; kd: {}; e: {}".
                     format(epoch, step_id,
                            loss.numpy(),
-                           ce_loss.numpy(), kd_loss.numpy(), e_loss.numpy()))
+                           ce_loss.numpy(), kd_loss.numpy(), e_loss))
             step_id += 1
         return objs.avg[0]

     def valid_one_epoch(self, valid_loader, epoch):
-        objs = AvgrageMeter()
-        top1 = AvgrageMeter()
-        top5 = AvgrageMeter()
         self.model.eval()
+        meters = {}
         for step_id, valid_data in enumerate(valid_loader):
-            image = to_variable(image)
-            label = to_variable(label)
-            n = image.shape[0]
-            logits = self.model(image)
-            prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
-            prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
-            loss = fluid.layers.reduce_mean(
-                fluid.layers.softmax_with_cross_entropy(logits, label))
-            objs.update(loss.numpy(), n)
-            top1.update(prec1.numpy(), n)
-            top5.update(prec5.numpy(), n)
+            ret = self.model.valid(valid_data)
+            for key, value in ret.items():
+                if key not in meters:
+                    meters[key] = AvgrageMeter()
+                meters[key].update(value, 1)
             if step_id % self.log_freq == 0:
                 logger.info(
-                    "Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
-                    format(epoch, step_id, objs.avg[0], top1.avg[0],
-                           top5.avg[0]))
-        return top1.avg[0]
+                    "Valid Epoch {}, Step {}, {}".format(epoch, step_id,
+                                                         meters))

     def train(self):
         if self.use_data_parallel:
             strategy = fluid.dygraph.parallel.prepare_context()
-        model_parameters = [
-            p for p in self.model.parameters()
-            if p.name not in [a.name for a in self.model.arch_parameters()]
-        ]
-        logger.info("param size = {:.6f}MB".format(
-            count_parameters_in_MB(model_parameters)))
+        model_parameters = self.model.model_parameters()
+        logger.info("parameter size in super net: {:.6f}M".format(
+            count_parameters_in_MB(model_parameters)))
         step_per_epoch = int(self.num_imgs * 0.5 / self.batchsize)
         if self.unrolled:
             step_per_epoch *= 2
         learning_rate = fluid.dygraph.CosineDecay(
             self.learning_rate, step_per_epoch, self.num_epochs)
         optimizer = fluid.optimizer.MomentumOptimizer(
             learning_rate,
             0.9,
@@ -167,6 +162,9 @@ class DARTSearch(object):
                 self.train_reader)
             self.valid_reader = fluid.contrib.reader.distributed_batch_reader(
                 self.valid_reader)
+            if self.test_reader is not None:
+                self.test_reader = fluid.contrib.reader.distributed_batch_reader(
+                    self.test_reader)

         train_loader = fluid.io.DataLoader.from_generator(
             capacity=64,
@@ -182,6 +180,17 @@ class DARTSearch(object):
         train_loader.set_batch_generator(self.train_reader, places=self.place)
         valid_loader.set_batch_generator(self.valid_reader, places=self.place)
+        if self.test_reader is not None:
+            test_loader = fluid.io.DataLoader.from_generator(
+                capacity=64,
+                use_double_buffer=True,
+                iterable=True,
+                return_list=True)
+            test_loader.set_batch_generator(self.test_reader, places=self.place)
+        else:
+            test_loader = valid_loader

         architect = Architect(self.model, learning_rate,
                               self.arch_learning_rate, self.place,
                               self.unrolled)
@@ -199,8 +208,8 @@ class DARTSearch(object):
             self.train_one_epoch(train_loader, valid_loader, architect,
                                  optimizer, epoch)

-            if epoch == self.num_epochs - 1:
-                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch, 1))
-            if save_parameters:
-                fluid.save_dygraph(self.model.state_dict(), "./weights")
+            # valid_top1 = self.valid_one_epoch(valid_loader, epoch)
+            # if epoch == self.num_epochs - 1:
+            #     self.valid_one_epoch(test_loader, epoch)
+            # if save_parameters:
+            #     fluid.save_dygraph(self.model.state_dict(), "./weights")
```
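Putting the training-side changes together, the inner bilevel search loop now has roughly this shape. A runnable sketch with stub objects (the real classes are `Architect` and `AdaBERTClassifier`); note the file is Python 2 style since it imports `itertools.izip`:

```python
import math
from itertools import izip  # Python 2; use the builtin zip() on Python 3

class _Stub(object):
    """Minimal stand-ins so the loop shape can run; not the real classes."""
    def step(self, train_data, valid_data):        # Architect.step
        return [0.0]                               # now returns the alphas gradient
    def loss(self, train_data):                    # model.loss
        return 1.0, 0.5, 0.3, 0.2                  # loss, ce_loss, kd_loss, e_loss

architect = model = _Stub()
train_loader = valid_loader = lambda: iter([("batch1",), ("batch2",)])

for train_data, valid_data in izip(train_loader(), valid_loader()):
    alphas_grad = architect.step(train_data, valid_data)     # arch update on valid batch
    loss, ce_loss, kd_loss, e_loss = model.loss(train_data)  # weight update on train batch
    if math.isnan(e_loss):  # the new NaN guard prints the alphas gradient
        print("alphas_grad: {}".format(alphas_grad))
```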
paddleslim/teachers/bert/model/transformer_encoder.py

```diff
@@ -57,7 +57,7 @@ class PrePostProcessLayer(Layer):
             elif cmd == "d":  # add dropout
                 if dropout_rate:
                     self.functors.append(lambda x: fluid.layers.dropout(
-                        x, dropout_prob=dropout_rate, is_test=False))
+                        x, dropout_prob=dropout_rate, is_test=True))
                     self.exec_order += "d"

     def forward(self, x, residual=None):
@@ -111,8 +111,8 @@ class PositionwiseFeedForwardLayer(Layer):
         hidden = fluid.layers.dropout(
             hidden,
             dropout_prob=self._dropout_rate,
-            upscale_in_train="upscale_in_train",
-            is_test=False)
+            # upscale_in_train="upscale_in_train",
+            is_test=True)
         out = self._h2o(hidden)
         return out
@@ -218,13 +218,13 @@ class MultiHeadAttentionLayer(Layer):
             #alpha=self._d_model**-0.5)
         if attn_bias is not None:
             product += attn_bias
         weights = fluid.layers.softmax(product)
+        # 48
         if self._dropout_rate:
             weights_droped = fluid.layers.dropout(
                 weights,
                 dropout_prob=self._dropout_rate,
-                dropout_implementation="upscale_in_train",
-                is_test=False)
+                # dropout_implementation="upscale_in_train",
+                is_test=True)
             out = fluid.layers.matmul(weights_droped, transpose_v)
         else:
             out = fluid.layers.matmul(weights, transpose_v)
```
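The common thread in this file is switching the teacher's dropout calls to `is_test=True`, making them deterministic so the distillation targets do not jitter between steps. A small check of the behavior, assuming a paddle 1.x `fluid` environment matching this repo; note that with the default `downgrade_in_infer` implementation, test-mode dropout scales by `1 - dropout_prob` rather than acting as a pure identity:

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones((2, 4), dtype="float32"))
    y = fluid.layers.dropout(x, dropout_prob=0.5, is_test=True)
    # Deterministic: every element equals 1.0 * (1 - 0.5) = 0.5 under the
    # default downgrade_in_infer implementation.
    print(y.numpy())
```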