PaddlePaddle / PaddleSlim

Commit d3e51bfc, authored Apr 20, 2020 by wanghaoshuang

1. Fix reader
2. Change gate conv to conv_bn_relu

Parent: 5ed02755
Showing 6 changed files with 272 additions and 95 deletions (+272 −95):

    paddleslim/nas/darts/architect.py                                           +2    −2
    paddleslim/nas/darts/search_space/conv_bert/cls.py                          +39   −13
    paddleslim/nas/darts/search_space/conv_bert/model/bert.py                   +22   −6
    paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py    +185  −69
    paddleslim/nas/darts/train_search.py                                        +14   −3
    paddleslim/teachers/bert/reader/cls.py                                      +10   −2
paddleslim/nas/darts/architect.py

@@ -60,8 +60,8 @@ class Architect(object):
     def _backward_step(self, valid_data):
         loss = self.model.loss(valid_data)
-        loss.backward()
-        return loss
+        loss[0].backward()
+        return loss[0]

     def _backward_step_unrolled(self, train_data, valid_data):
         self._compute_unrolled_model(train_data)
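The indexing change above tracks the new return contract of the searched model's loss() (see cls.py below): it now returns four values, and only the first is the trainable objective. A minimal sketch of that contract with stand-in numbers:

# Stand-in for AdaBERTClassifier.loss(), which after this commit returns
# (total_loss, ce_loss, kd_loss, e_loss) instead of a single Variable.
def model_loss(data):
    return 1.2, 0.7, 0.3, 0.2

loss = model_loss(None)
print(loss[0])  # the combined objective; the real code calls loss[0].backward()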
paddleslim/nas/darts/search_space/conv_bert/cls.py

@@ -41,27 +41,44 @@ __all__ = ["AdaBERTClassifier"]
 class AdaBERTClassifier(Layer):
     def __init__(self,
                  num_labels,
                  n_layer=8,
-                 emb_size=768,
+                 emb_size=128,
+                 hidden_size=768,
+                 gamma=0.8,
+                 beta=4,
+                 conv_type="conv_bn",
+                 search_layer=True,
                  teacher_model=None):
         super(AdaBERTClassifier, self).__init__()
         self._n_layer = n_layer
         self._num_labels = num_labels
         self._emb_size = emb_size
+        self._hidden_size = hidden_size
+        self._gamma = gamma
+        self._beta = beta
+        self._conv_type = conv_type
+        self._search_layer = search_layer
         print(
             "----------------------load teacher model and test----------------------------------------"
         )
         self.teacher = BERTClassifier(num_labels, model_path=teacher_model)
-        # self.teacher.test("/work/PaddleSlim/demo/bert/data/glue_data/MNLI/")
+        self.teacher.test("/work/PaddleSlim/demo/bert/data/glue_data/MNLI/")
         print(
             "----------------------finish load teacher model and test----------------------------------------"
         )
-        self.student = BertModelLayer(
-            n_layer=self._n_layer, emb_size=self._emb_size)
+        self.student = BertModelLayer(
+            n_layer=self._n_layer,
+            emb_size=self._emb_size,
+            hidden_size=self._hidden_size,
+            conv_type=self._conv_type,
+            search_layer=self._search_layer)
         self.cls_fc = list()
         for i in range(self._n_layer):
             fc = Linear(
-                input_dim=self._emb_size,
+                input_dim=self._hidden_size,
                 output_dim=self._num_labels,
                 param_attr=fluid.ParamAttr(
                     name="s_cls_out_%d_w" % i,

@@ -84,7 +101,14 @@ class AdaBERTClassifier(Layer):
     def genotype(self):
         return self.arch_parameters()

-    def loss(self, data_ids, beta=4, gamma=0.8):
+    def new(self):
+        model_new = AdaBERTClassifier(
+            3,
+            teacher_model="/work/PaddleSlim/demo/bert_1/checkpoints/steps_23000")
+        return model_new
+
+    def loss(self, data_ids):
         T = 1.0
         src_ids = data_ids[0]
         position_ids = data_ids[1]

@@ -130,7 +154,7 @@ class AdaBERTClassifier(Layer):
         t_probs = fluid.layers.softmax(t_logit)
         s_probs = fluid.layers.softmax(s_logits)
-        t_probs.stop_gradient = False
+        t_probs.stop_gradient = True
         kd_loss = t_probs * fluid.layers.log(s_probs / T)
         kd_loss = fluid.layers.reduce_sum(kd_loss, dim=1)
         kd_loss = kd_loss * kd_weights[i]

@@ -144,14 +168,16 @@ class AdaBERTClassifier(Layer):
             ce_loss = fluid.layers.reduce_mean(ce_loss) * k_i

         # define e loss
-        model_size = fluid.layers.sum(
-            model_size) / self.student.max_model_size()
+        model_size = fluid.layers.sum(
+            model_size)
+        # print("model_size: {}".format(model_size.numpy()/1e6))
+        model_size = model_size / self.student.max_model_size()
         flops = fluid.layers.sum(flops) / self.student.max_flops()
         e_loss = (len(next_sent_feats) * k_i / self._n_layer) * (
             flops + model_size)

         # define total loss
-        loss = (1 - gamma) * ce_loss - gamma * kd_loss + beta * e_loss
-        print("ce_loss: {}; kd_loss: {}; e_loss: {}".format((
-            1 - gamma) * ce_loss.numpy(), -gamma * kd_loss.numpy(), beta *
-            e_loss.numpy()))
-        return loss
+        loss = (1 - self._gamma) * ce_loss - self._gamma * kd_loss + self._beta * e_loss
+        # print("ce_loss: {}; kd_loss: {}; e_loss: {}".format((
+        #     1 - gamma) * ce_loss.numpy(), -gamma * kd_loss.numpy(), beta *
+        #     e_loss.numpy()))
+        return loss, ce_loss, kd_loss, e_loss
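Put together, loss() now computes the search objective loss = (1 - gamma) * ce_loss - gamma * kd_loss + beta * e_loss, with gamma and beta read from the constructor instead of per-call arguments. A runnable numpy sketch of the distillation term and the mixing, using toy logits and the defaults T=1.0, gamma=0.8, beta=4 (ce_loss and e_loss are placeholder values here, not the repository's computation):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

T, gamma, beta = 1.0, 0.8, 4.0
t_logit = np.array([[2.0, 0.5, -1.0]])   # teacher logits; stop_gradient=True above
s_logits = np.array([[1.5, 0.2, -0.3]])  # student logits

t_probs = softmax(t_logit)
s_probs = softmax(s_logits)
kd_loss = (t_probs * np.log(s_probs / T)).sum(axis=1)  # reduce_sum(..., dim=1)

ce_loss, e_loss = 0.7, 0.4               # placeholders for the other two terms
total = (1 - gamma) * ce_loss - gamma * kd_loss + beta * e_loss
print(total)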
paddleslim/nas/darts/search_space/conv_bert/model/bert.py

@@ -28,17 +28,21 @@ from .transformer_encoder import EncoderLayer
 class BertModelLayer(Layer):
     def __init__(self,
-                 emb_size=768,
+                 emb_size=128,
+                 hidden_size=768,
                  n_layer=12,
                  voc_size=30522,
                  max_position_seq_len=512,
                  sent_types=2,
                  return_pooled_out=True,
                  initializer_range=1.0,
+                 conv_type="conv_bn",
+                 search_layer=True,
                  use_fp16=False):
         super(BertModelLayer, self).__init__()

         self._emb_size = emb_size
+        self._hidden_size = hidden_size
         self._n_layer = n_layer
         self._voc_size = voc_size
         self._max_position_seq_len = max_position_seq_len

@@ -50,6 +54,8 @@ class BertModelLayer(Layer):
         self._sent_emb_name = "s_sent_embedding"
         self._dtype = "float16" if use_fp16 else "float32"
+        self._conv_type = conv_type
+        self._search_layer = search_layer

         self._param_initializer = fluid.initializer.TruncatedNormal(
             scale=initializer_range)

@@ -71,16 +77,24 @@ class BertModelLayer(Layer):
                 name=self._sent_emb_name, initializer=self._param_initializer),
             dtype=self._dtype)

-        self.pooled_fc = Linear(
-            input_dim=self._emb_size,
-            output_dim=self._emb_size,
-            param_attr=fluid.ParamAttr(
-                name="s_pooled_fc.w_0", initializer=self._param_initializer),
-            bias_attr="s_pooled_fc.b_0",
-            act="tanh")
+        self._emb_fac = Linear(
+            input_dim=self._emb_size,
+            output_dim=self._hidden_size,
+            param_attr=fluid.ParamAttr(name="s_emb_factorization"))
+
+        self.pooled_fc = Linear(
+            input_dim=self._hidden_size,
+            output_dim=self._hidden_size,
+            param_attr=fluid.ParamAttr(
+                name="s_pooled_fc.w_0", initializer=self._param_initializer),
+            bias_attr="s_pooled_fc.b_0",
+            act="tanh")

-        self._encoder = EncoderLayer(
-            n_layer=self._n_layer, d_model=self._emb_size)
+        self._encoder = EncoderLayer(
+            n_layer=self._n_layer,
+            hidden_size=self._hidden_size,
+            conv_type=self._conv_type,
+            search_layer=self._search_layer)

     def max_flops(self):
         return self._encoder.max_flops
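The s_emb_factorization Linear added above is an ALBERT-style factorized embedding: token embeddings stay at emb_size=128 and are projected up to hidden_size=768 before entering the encoder. The parameter arithmetic, using the defaults from this file:

# Embedding parameters with and without the factorization (defaults above).
voc_size, emb_size, hidden_size = 30522, 128, 768
factorized = voc_size * emb_size + emb_size * hidden_size
full = voc_size * hidden_size
print(factorized, full, round(full / factorized, 1))  # ~5.9x fewer parameters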
@@ -89,7 +103,7 @@ class BertModelLayer(Layer):
         return self._encoder.max_model_size

     def arch_parameters(self):
-        return [self._encoder.alphas]
+        return [self._encoder.alphas, self._encoder.k]

     def forward(self,
                 src_ids,

@@ -107,6 +121,8 @@ class BertModelLayer(Layer):
         emb_out = src_emb + pos_emb
         emb_out = emb_out + sent_emb

+        emb_out = self._emb_fac(emb_out)
+
         enc_outputs, k_i = self._encoder(
             emb_out, flops=flops, model_size=model_size)

@@ -118,7 +134,7 @@ class BertModelLayer(Layer):
                 input=enc_output, axes=[1], starts=[0], ends=[1])
             next_sent_feat = self.pooled_fc(next_sent_feat)
             next_sent_feat = fluid.layers.reshape(
-                next_sent_feat, shape=[-1, self._emb_size])
+                next_sent_feat, shape=[-1, self._hidden_size])
             next_sent_feats.append(next_sent_feat)

         return enc_outputs, next_sent_feats, k_i
paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py

@@ -24,57 +24,92 @@ import paddle.fluid as fluid
 from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, Conv2D, BatchNorm, Pool2D, to_variable
 from paddle.fluid.initializer import NormalInitializer

-PRIMITIVES = [
-    'std_conv_3', 'std_conv_5', 'std_conv_7', 'dil_conv_3', 'dil_conv_5',
-    'dil_conv_7', 'avg_pool_3', 'max_pool_3', 'none', 'skip_connect'
-]
+GConv_PRIMITIVES = [
+    'std_gconv_3', 'std_gconv_5', 'std_gconv_7', 'dil_gconv_3', 'dil_gconv_5',
+    'dil_gconv_7', 'avg_pool_3', 'max_pool_3', 'none', 'skip_connect'
+]
+
+ConvBN_PRIMITIVES = [
+    'std_conv_bn_3', 'std_conv_bn_5', 'std_conv_bn_7', 'dil_conv_bn_3',
+    'dil_conv_bn_5', 'dil_conv_bn_7', 'avg_pool_3', 'max_pool_3', 'none',
+    'skip_connect'
+]

-input_size = 128 * 768
+channel = 768
+input_size = 128 * 1

 FLOPs = {
-    'std_conv_3': input_size * 3 * 1,
-    'std_conv_5': input_size * 5 * 1,
-    'std_conv_7': input_size * 7 * 1,
-    'dil_conv_3': input_size * 3 * 1,
-    'dil_conv_5': input_size * 5 * 1,
-    'dil_conv_7': input_size * 7 * 1,
-    'avg_pool_3': input_size * 3 * 1,
-    'max_pool_3': input_size * 3 * 1,
+    'std_conv_bn_3': input_size * (channel**2) * 3,
+    'std_conv_bn_5': input_size * (channel**2) * 5,
+    'std_conv_bn_7': input_size * (channel**2) * 7,
+    'dil_conv_bn_3': input_size * (channel**2) * 3,
+    'dil_conv_bn_5': input_size * (channel**2) * 5,
+    'dil_conv_bn_7': input_size * (channel**2) * 7,
+    'std_gconv_3': input_size * (channel**2) * 3,
+    'std_gconv_5': input_size * (channel**2) * 5,
+    'std_gconv_7': input_size * (channel**2) * 7,
+    'dil_gconv_3': input_size * (channel**2) * 3,
+    'dil_gconv_5': input_size * (channel**2) * 5,
+    'dil_gconv_7': input_size * (channel**2) * 7,
+    'avg_pool_3': input_size * channel * 3 * 1,
+    'max_pool_3': input_size * channel * 3 * 1,
     'none': 0,
     'skip_connect': 0,
 }

 ModelSize = {
-    'std_conv_3': 3 * 1,
-    'std_conv_5': 5 * 1,
-    'std_conv_7': 7 * 1,
-    'dil_conv_3': 3 * 1,
-    'dil_conv_5': 5 * 1,
-    'dil_conv_7': 7 * 1,
+    'std_conv_bn_3': (channel**2) * 3 * 1,
+    'std_conv_bn_5': (channel**2) * 5 * 1,
+    'std_conv_bn_7': (channel**2) * 7 * 1,
+    'dil_conv_bn_3': (channel**2) * 3 * 1,
+    'dil_conv_bn_5': (channel**2) * 5 * 1,
+    'dil_conv_bn_7': (channel**2) * 7 * 1,
+    'std_gconv_3': (channel**2) * 3 * 1,
+    'std_gconv_5': (channel**2) * 5 * 1,
+    'std_gconv_7': (channel**2) * 7 * 1,
+    'dil_gconv_3': (channel**2) * 3 * 1,
+    'dil_gconv_5': (channel**2) * 5 * 1,
+    'dil_gconv_7': (channel**2) * 7 * 1,
     'avg_pool_3': 0,
     'max_pool_3': 0,
     'none': 0,
     'skip_connect': 0,
 }

 OPS = {
-    'std_conv_3': lambda: ConvBN(1, 1, filter_size=3, dilation=1),
-    'std_conv_5': lambda: ConvBN(1, 1, filter_size=5, dilation=1),
-    'std_conv_7': lambda: ConvBN(1, 1, filter_size=7, dilation=1),
-    'dil_conv_3': lambda: ConvBN(1, 1, filter_size=3, dilation=2),
-    'dil_conv_5': lambda: ConvBN(1, 1, filter_size=5, dilation=2),
-    'dil_conv_7': lambda: ConvBN(1, 1, filter_size=7, dilation=2),
-    'avg_pool_3': lambda: Pool2D(pool_size=(3, 1), pool_padding=(1, 0), pool_type='avg'),
-    'max_pool_3': lambda: Pool2D(pool_size=(3, 1), pool_padding=(1, 0), pool_type='max'),
-    'none': lambda: Zero(),
-    'skip_connect': lambda: Identity(),
+    'std_gconv_3': lambda n_channel, name: GateConv(n_channel, n_channel, filter_size=[3, 1], dilation=1, name=name),
+    'std_gconv_5': lambda n_channel, name: GateConv(n_channel, n_channel, filter_size=[5, 1], dilation=1, name=name),
+    'std_gconv_7': lambda n_channel, name: GateConv(n_channel, n_channel, filter_size=[7, 1], dilation=1, name=name),
+    'dil_gconv_3': lambda n_channel, name: GateConv(n_channel, n_channel, filter_size=[3, 1], dilation=2, name=name),
+    'dil_gconv_5': lambda n_channel, name: GateConv(n_channel, n_channel, filter_size=[5, 1], dilation=2, name=name),
+    'dil_gconv_7': lambda n_channel, name: GateConv(n_channel, n_channel, filter_size=[7, 1], dilation=2, name=name),
+    'std_conv_bn_3': lambda n_channel, name: ConvBNRelu(n_channel, n_channel, filter_size=[3, 1], dilation=1, name=name),
+    'std_conv_bn_5': lambda n_channel, name: ConvBNRelu(n_channel, n_channel, filter_size=[5, 1], dilation=1, name=name),
+    'std_conv_bn_7': lambda n_channel, name: ConvBNRelu(n_channel, n_channel, filter_size=[7, 1], dilation=1, name=name),
+    'dil_conv_bn_3': lambda n_channel, name: ConvBNRelu(n_channel, n_channel, filter_size=[3, 1], dilation=2, name=name),
+    'dil_conv_bn_5': lambda n_channel, name: ConvBNRelu(n_channel, n_channel, filter_size=[5, 1], dilation=2, name=name),
+    'dil_conv_bn_7': lambda n_channel, name: ConvBNRelu(n_channel, n_channel, filter_size=[7, 1], dilation=2, name=name),
+    'avg_pool_3': lambda n_channel, name: Pool2D(pool_size=(3, 1), pool_padding=(1, 0), pool_type='avg'),
+    'max_pool_3': lambda n_channel, name: Pool2D(pool_size=(3, 1), pool_padding=(1, 0), pool_type='max'),
+    'none': lambda n_channel, name: Zero(),
+    'skip_connect': lambda n_channel, name: Identity(),
 }

 class MixedOp(fluid.dygraph.Layer):
-    def __init__(self):
+    def __init__(self, n_channel, name=None, conv_type="conv_bn"):
         super(MixedOp, self).__init__()
-        ops = [OPS[primitive]() for primitive in PRIMITIVES]
+        if conv_type == "conv_bn":
+            PRIMITIVES = ConvBN_PRIMITIVES
+        elif conv_type == "gconv":
+            PRIMITIVES = GConv_PRIMITIVES
+        ops = [
+            OPS[primitive](n_channel, name
+                           if name is None else name + "/" + primitive)
+            for primitive in PRIMITIVES
+        ]
         self._ops = fluid.dygraph.LayerList(ops)
         self.max_flops = max([FLOPs[primitive] for primitive in PRIMITIVES])
         self.max_model_size = max(
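The rebuilt cost tables now charge each convolution for a full channel x channel kernel slice (the old FLOPs table priced ops as if they were single-channel). A quick arithmetic check of the constants above:

# Sanity check: with input_size = 128 * 1 and channel = 768, one kernel-3 conv
# edge costs ~226M multiply-adds and holds ~1.77M weights.
channel = 768
input_size = 128 * 1
print(input_size * (channel**2) * 3)  # 226492416, FLOPs['std_conv_bn_3']
print((channel**2) * 3 * 1)           # 1769472, ModelSize['std_conv_bn_3']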
@@ -121,39 +156,88 @@ def gumbel_softmax(logits, temperature=0.1, hard=True, eps=1e-20):
     return out

-class ConvBN(fluid.dygraph.Layer):
+class ConvBNRelu(fluid.dygraph.Layer):
     def __init__(self,
-                 out_ch,
-                 in_ch,
-                 filter_size=3,
+                 in_c=768,
+                 out_c=768,
+                 filter_size=[3, 1],
                  dilation=1,
-                 act="relu",
                  is_test=False,
-                 use_cudnn=True):
-        super(ConvBN, self).__init__()
-        conv_std = (2.0 / (filter_size**2 * in_ch))**0.5
-        conv_param = fluid.ParamAttr(
-            initializer=fluid.initializer.Normal(0.0, conv_std))
+                 use_cudnn=True,
+                 name=None):
+        super(ConvBNRelu, self).__init__()
+        conv_std = (2.0 /
+                    (filter_size[0] * filter_size[1] * out_c * in_c))**0.5
+        conv_param = fluid.ParamAttr(
+            name=name if name is None else (name + "_conv.weights"),
+            initializer=fluid.initializer.Normal(0.0, conv_std))

-        self.conv_layer = Conv2D(
-            in_ch, out_ch, [filter_size, 1],
-            dilation=[dilation, 1],
-            padding=[(filter_size - 1) * dilation // 2, 0],
-            param_attr=conv_param,
-            bias_attr=False,
-            act=None,
-            use_cudnn=use_cudnn)
-        self.bn_layer = BatchNorm(out_ch, act=act, is_test=is_test)
+        self.conv = Conv2D(
+            in_c,
+            out_c,
+            filter_size,
+            dilation=[dilation, 1],
+            padding=[(filter_size[0] - 1) * dilation // 2, 0],
+            param_attr=conv_param,
+            act=None,
+            bias_attr=False,
+            use_cudnn=use_cudnn)
+        self.bn = BatchNorm(out_c, act="relu", is_test=False)

     def forward(self, inputs):
-        conv = self.conv_layer(inputs)
-        bn = self.bn_layer(conv)
+        conv = self.conv(inputs)
+        bn = self.bn(conv)
         return bn

+class GateConv(fluid.dygraph.Layer):
+    def __init__(self,
+                 in_c=768,
+                 out_c=768,
+                 filter_size=[3, 1],
+                 dilation=1,
+                 is_test=False,
+                 use_cudnn=True,
+                 name=None):
+        super(GateConv, self).__init__()
+        conv_std = (2.0 /
+                    (filter_size[0] * filter_size[1] * out_c * in_c))**0.5
+        conv_param = fluid.ParamAttr(
+            name=name if name is None else (name + "_conv.weights"),
+            initializer=fluid.initializer.Normal(0.0, conv_std))
+        gate_param = fluid.ParamAttr(
+            name=name if name is None else (name + "_conv_gate.weights"),
+            initializer=fluid.initializer.Normal(0.0, conv_std))
+
+        self.conv = Conv2D(
+            in_c,
+            out_c,
+            filter_size,
+            dilation=[dilation, 1],
+            padding=[(filter_size[0] - 1) * dilation // 2, 0],
+            param_attr=conv_param,
+            act=None,
+            use_cudnn=use_cudnn)
+        self.gate = Conv2D(
+            in_c,
+            out_c,
+            filter_size,
+            dilation=[dilation, 1],
+            padding=[(filter_size[0] - 1) * dilation // 2, 0],
+            param_attr=gate_param,
+            act="sigmoid",
+            use_cudnn=use_cudnn)
+
+    def forward(self, inputs):
+        conv = self.conv(inputs)
+        gate = self.gate(inputs)
+        return conv * gate
+
 class Cell(fluid.dygraph.Layer):
-    def __init__(self, steps):
+    def __init__(self, steps, n_channel, name=None, conv_type="conv_bn"):
         super(Cell, self).__init__()
         self._steps = steps
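GateConv, kept for the "gconv" search space, is a gated linear unit: the same convolution geometry is applied twice, once linearly and once through a sigmoid, and the two results are multiplied. In numpy, with scalar stand-ins for the two weight tensors:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-2, 2, 5)         # toy features; the real branches are Conv2D
w_conv, w_gate = 0.5, -1.0        # stand-ins for conv and gate weights
out = (w_conv * x) * sigmoid(w_gate * x)
print(out)                        # the gate suppresses positions it maps near 0

Per the commit title, the default search space now uses ConvBNRelu instead; the gated variant remains selectable via conv_type="gconv".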
@@ -162,7 +246,11 @@ class Cell(fluid.dygraph.Layer):
         ops = []
         for i in range(self._steps):
             for j in range(2 + i):
-                op = MixedOp()
+                op = MixedOp(
+                    n_channel,
+                    name=name
+                    if name is None else "%s/step%d_edge%d" % (name, i, j),
+                    conv_type=conv_type)
                 self.max_flops += op.max_flops
                 self.max_model_size += op.max_model_size
                 ops.append(op)
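MixedOp.forward is not shown in this diff, but in DARTS-style search an edge evaluates every candidate primitive and blends the results with one row of the (gumbel-softmaxed) alphas. A numpy sketch of that blend, with toy stand-ins for three primitives:

import numpy as np

ops = [lambda x: x * 0.0,         # 'none'
       lambda x: x,               # 'skip_connect'
       lambda x: np.tanh(x)]      # stand-in for a conv primitive
alpha_row = np.array([0.1, 0.2, 0.7])   # one row of alphas, sums to 1
x = np.array([1.0, -1.0])
print(sum(w * op(x) for w, op in zip(alpha_row, ops)))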
@@ -191,19 +279,49 @@ class EncoderLayer(Layer):
         encoder
     """

-    def __init__(self, n_layer, d_model=128, name=""):
+    def __init__(self,
+                 n_layer,
+                 hidden_size=768,
+                 name="encoder",
+                 conv_type="conv_bn",
+                 search_layer=True):
         super(EncoderLayer, self).__init__()
         cells = []
         self._n_layer = n_layer
-        self._d_model = d_model
+        self._hidden_size = hidden_size
         self._steps = 3
+        self._search_layer = search_layer
         self.max_flops = 0
         self.max_model_size = 0
+        if conv_type == "conv_bn":
+            self._n_ops = len(ConvBN_PRIMITIVES)
+            self.conv0 = ConvBNRelu(
+                in_c=1,
+                out_c=self._hidden_size,
+                filter_size=[3, self._hidden_size],
+                dilation=1,
+                is_test=False,
+                use_cudnn=True,
+                name="conv0")
+        elif conv_type == "gconv":
+            self._n_ops = len(GConv_PRIMITIVES)
+            self.conv0 = GateConv(
+                in_c=1,
+                out_c=self._hidden_size,
+                filter_size=[3, self._hidden_size],
+                dilation=1,
+                is_test=False,
+                use_cudnn=True,
+                name="conv0")

         cells = []
         for i in range(n_layer):
-            cell = Cell(steps=self._steps)
+            cell = Cell(
+                steps=self._steps,
+                n_channel=self._hidden_size,
+                name="%s/layer_%d" % (name, i),
+                conv_type=conv_type)
             cells.append(cell)
             self.max_flops += cell.max_flops
             self.max_model_size += cell.max_model_size
@@ -211,7 +329,7 @@ class EncoderLayer(Layer):
         self._cells = fluid.dygraph.LayerList(cells)

         k = sum(1 for i in range(self._steps) for n in range(2 + i))
-        num_ops = len(PRIMITIVES)
+        num_ops = self._n_ops
         self.alphas = fluid.layers.create_parameter(
             shape=[k, num_ops],
             dtype="float32",
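The alphas shape follows from the cell topology: with _steps = 3, step i has 2 + i incoming edges, so each cell holds 9 searchable edges, and each edge chooses among self._n_ops primitives (10 in both spaces defined above):

steps = 3
k = sum(1 for i in range(steps) for n in range(2 + i))
num_ops = 10                  # len(ConvBN_PRIMITIVES) == len(GConv_PRIMITIVES)
print(k, [k, num_ops])        # 9 edges -> alphas has shape [9, 10]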
@@ -225,14 +343,11 @@ class EncoderLayer(Layer):
             loc=0.0, scale=1e-3))

     def forward(self, enc_input, flops=[], model_size=[]):
-        """
-        forward
-        :param enc_input:
-        :param attn_bias:
-        :return:
-        """
-        tmp = fluid.layers.reshape(
-            enc_input, [-1, 1, enc_input.shape[1], self._d_model])
+        tmp = fluid.layers.reshape(
+            enc_input, [-1, 1, enc_input.shape[1],
+                        self._hidden_size])  #(bs, 1, seq_len, hidden_size)
+
+        tmp = self.conv0(tmp)  # (bs, hidden_size, seq_len, 1)

         alphas = gumbel_softmax(self.alphas)
         k = fluid.layers.reshape(gumbel_softmax(self.k), [-1])
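Both self.alphas and the new layer gate self.k are sampled through gumbel_softmax, defined earlier in this file with signature gumbel_softmax(logits, temperature=0.1, hard=True, eps=1e-20). A numpy sketch of the soft version, ignoring the hard straight-through step:

import numpy as np

def gumbel_softmax(logits, temperature=0.1, eps=1e-20):
    u = np.random.uniform(size=logits.shape)
    g = -np.log(-np.log(u + eps) + eps)           # Gumbel(0, 1) noise
    y = (logits + g) / temperature
    e = np.exp(y - y.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)      # near one-hot at low temperature

alphas = gumbel_softmax(np.random.randn(9, 10))
print(alphas.sum(axis=-1))                        # each edge's mixture sums to 1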
@@ -240,15 +355,16 @@ class EncoderLayer(Layer):
         outputs = []
         s0 = s1 = tmp
         for i in range(self._n_layer):
-            s0, s1 = s1, self._cells[i](s0,
-                                        s1,
-                                        alphas, flops=flops,
-                                        model_size=model_size)
-            enc_output = fluid.layers.reshape(
-                s1, [-1, enc_input.shape[1], self._d_model])
+            s0, s1 = s1, self._cells[i](
+                s0,
+                s1,
+                alphas,
+                flops=flops,
+                model_size=model_size)  # (bs, hidden_size, seq_len, 1)
+            enc_output = fluid.layers.transpose(
+                s1, [0, 2, 1, 3])  # (bs, seq_len, hidden_size, 1)
+            enc_output = fluid.layers.reshape(
+                enc_output, [-1, enc_output.shape[1],
+                             self._hidden_size])  # (bs, seq_len, hidden_size)
             outputs.append(enc_output)
-            if k[i].numpy() != 0:
+            if self._search_layer and k[i].numpy() != 0:
                 outputs[-1] = outputs[-1] * k[i]
                 return outputs, k[i]
-        return None
+        return outputs, 1.0
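The rewritten loop makes layer search explicit: k is a (near) one-hot vector over layers, the first non-zero entry scales that layer's output and truncates the stack, and when search_layer is off (or no entry fires) the function now returns a well-formed (outputs, 1.0) instead of None, which callers unpacking enc_outputs, k_i could not handle. The control flow, in plain Python:

import numpy as np

k = np.array([0.0, 0.0, 1.0, 0.0])     # sampled from gumbel_softmax(self.k)
outputs = []
for i, layer_out in enumerate([10.0, 20.0, 30.0, 40.0]):
    outputs.append(layer_out)
    if k[i] != 0:
        outputs[-1] = outputs[-1] * k[i]
        break                           # 'return outputs, k[i]' in the real code
print(outputs)                          # [10.0, 20.0, 30.0] -> depth 3 selected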
paddleslim/nas/darts/train_search.py

@@ -74,6 +74,9 @@ class DARTSearch(object):
     def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
                         epoch):
         objs = AvgrageMeter()
+        ce_losses = AvgrageMeter()
+        kd_losses = AvgrageMeter()
+        e_losses = AvgrageMeter()
         self.model.train()

         step_id = 0

@@ -81,7 +84,7 @@ class DARTSearch(object):
             if epoch >= self.epochs_no_archopt:
                 architect.step(train_data, valid_data)

-            loss = self.model.loss(train_data)
+            loss, ce_loss, kd_loss, e_loss = self.model.loss(train_data)

             if self.use_data_parallel:
                 loss = self.model.scale_loss(loss)

@@ -96,10 +99,18 @@ class DARTSearch(object):
             batch_size = train_data[0].shape[0]
             objs.update(loss.numpy(), batch_size)
+            ce_losses.update(ce_loss.numpy(), batch_size)
+            kd_losses.update(kd_loss.numpy(), batch_size)
+            e_losses.update(e_loss.numpy(), batch_size)

             if step_id % self.log_freq == 0:
-                logger.info("Train Epoch {}, Step {}, loss {:.6f}".format(
-                    epoch, step_id, objs.avg[0]))
+                #logger.info("Train Epoch {}, Step {}, loss {:.6f}; ce: {:.6f}; kd: {:.6f}; e: {:.6f}".format(
+                #    epoch, step_id, objs.avg[0], ce_losses.avg[0], kd_losses.avg[0], e_losses.avg[0]))
+                logger.info(
+                    "Train Epoch {}, Step {}, loss {}; ce: {}; kd: {}; e: {}".
+                    format(epoch, step_id,
+                           loss.numpy(), ce_loss.numpy(), kd_loss.numpy(),
+                           e_loss.numpy()))
             step_id += 1
         return objs.avg[0]
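The three new meters accumulate batch-size-weighted running averages of the loss components; only the raw per-step values are logged for now, with the averaged variant left commented out. A sketch of the meter semantics as used above (the real AvgrageMeter in PaddleSlim stores its average as an array, hence objs.avg[0]):

class AvgrageMeterSketch:
    def __init__(self):
        self.sum, self.cnt = 0.0, 0
    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
    @property
    def avg(self):
        return self.sum / max(self.cnt, 1)

ce = AvgrageMeterSketch()
for batch_loss, bs in [(0.9, 32), (0.7, 32)]:
    ce.update(batch_loss, bs)
print(ce.avg)  # 0.8, the batch-size-weighted running average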
paddleslim/teachers/bert/reader/cls.py

@@ -144,6 +144,14 @@ class DataProcessor(object):
         elif phase == 'test':
             examples = self.get_test_examples(self.data_dir)
             self.num_examples['test'] = len(examples)
+        elif phase == 'search_train':
+            examples = self.get_train_examples(self.data_dir)
+            self.num_examples['search_train'] = len(examples) / 2
+            examples = examples[:self.num_examples['search_train']]
+        elif phase == 'search_valid':
+            examples = self.get_train_examples(self.data_dir)
+            self.num_examples['search_valid'] = len(examples) / 2
+            examples = examples[self.num_examples['search_train']:]
         else:
             raise ValueError(
                 "Unknown phase, which should be in ['train', 'dev', 'test'].")

@@ -154,10 +162,10 @@ class DataProcessor(object):
                 if shuffle_seed is not None:
                     np.random.seed(shuffle_seed)
                 np.random.shuffle(examples)
-            if phase == 'train':
+            if phase == 'train' or phase == 'search_train':
                 self.current_train_epoch = epoch_index

             for (index, example) in enumerate(examples):
-                if phase == 'train':
+                if phase == 'train' or phase == "search_train":
                     self.current_train_example = index + 1
                 feature = self.convert_example(
                     index, example,
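One caveat worth noting in the new phases: len(examples) / 2 is float division under Python 3, and a float cannot be used as a slice bound, so the split above only works as written under Python 2. A Python 3 port would need integer division:

examples = list(range(10))
half = len(examples) // 2                     # not len(examples) / 2
search_train, search_valid = examples[:half], examples[half:]
print(len(search_train), len(search_valid))   # 5 5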