PaddlePaddle / PaddleSlim
Commit 343c64cc
Authored on Apr 15, 2020 by wanghaoshuang

Update bert DARTS

Parent: 649ffd9e
Showing 4 changed files with 119 additions and 390 deletions (+119 −390)
paddleslim/nas/darts/search_space/conv_bert/cls.py         +82  −201
paddleslim/nas/darts/search_space/conv_bert/model/bert.py  +24  −155
paddleslim/nas/darts/search_space/conv_bert/model/cls.py   +1   −0
paddleslim/nas/darts/train_search.py                       +12  −34
paddleslim/nas/darts/search_space/conv_bert/cls.py

@@ -30,207 +30,88 @@ import numpy as np
 import multiprocessing
 import paddle
 import paddle.fluid as fluid
-from paddle.fluid.dygraph import to_variable, Layer
+from paddle.fluid.dygraph import to_variable, Layer, Linear
 from .reader.cls import *
-from .model.bert import BertConfig
-from .model.cls import ClsModelLayer
+from .model.bert import BertModelLayer
 from .optimization import Optimizer
 from .utils.init import init_from_static_model
+from paddleslim.teachers.bert import BERTClassifier

-__all__ = ["ConvBERTClassifier"]
-
-
-def create_data(batch):
-    """
-    convert data to variable
-    """
-    src_ids = to_variable(batch[0], "src_ids")
-    position_ids = to_variable(batch[1], "position_ids")
-    sentence_ids = to_variable(batch[2], "sentence_ids")
-    input_mask = to_variable(batch[3], "input_mask")
-    labels = to_variable(batch[4], "labels")
-    labels.stop_gradient = True
-    return src_ids, position_ids, sentence_ids, input_mask, labels
-
-
-class ConvBERTClassifier(Layer):
-    def __init__(self,
-                 num_labels,
-                 task_name="mnli",
-                 model_path=None,
-                 use_cuda=True):
-        super(ConvBERTClassifier, self).__init__()
-        self.task_name = task_name.lower()
-        BERT_BASE_PATH = "./data/pretrained_models/uncased_L-12_H-768_A-12/"
-        bert_config_path = BERT_BASE_PATH + "/bert_config.json"
-        self.vocab_path = BERT_BASE_PATH + "/vocab.txt"
-        self.init_pretraining_params = BERT_BASE_PATH + "/dygraph_params/"
-        self.do_lower_case = True
-        self.bert_config = BertConfig(bert_config_path)
-
-        if use_cuda:
-            self.dev_count = fluid.core.get_cuda_device_count()
-        else:
-            self.dev_count = int(
-                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-        self.trainer_count = fluid.dygraph.parallel.Env().nranks
-
-        self.processors = {
-            'xnli': XnliProcessor,
-            'cola': ColaProcessor,
-            'mrpc': MrpcProcessor,
-            'mnli': MnliProcessor,
-        }
-
-        self.cls_model = ClsModelLayer(
-            self.bert_config, num_labels, return_pooled_out=True)
-
-        if model_path is not None:
-            #restore the model
-            print("Load params from %s" % model_path)
-            model_dict, _ = fluid.load_dygraph(model_path)
-            self.cls_model.load_dict(model_dict)
-
-    def forward(self, input):
-        return self.cls_model(input)
-
-    def test(self, data_dir, batch_size=64, max_seq_len=512):
-        processor = self.processors[self.task_name](
-            data_dir=data_dir,
-            vocab_path=self.vocab_path,
-            max_seq_len=max_seq_len,
-            do_lower_case=self.do_lower_case,
-            in_tokens=False)
-
-        test_data_generator = processor.data_generator(
-            batch_size=batch_size, phase='dev', epoch=1, shuffle=False)
-
-        self.cls_model.eval()
-        total_cost, final_acc, avg_acc, total_num_seqs = [], [], [], []
-        for batch in test_data_generator():
-            data_ids = create_data(batch)
-            total_loss, _, _, np_acces, np_num_seqs = self.cls_model(data_ids)
-            np_loss = total_loss.numpy()
-            np_acc = np_acces[-1].numpy()
-            np_avg_acc = np.mean([acc.numpy() for acc in np_acces])
-            np_num_seqs = np_num_seqs.numpy()
-            total_cost.extend(np_loss * np_num_seqs)
-            final_acc.extend(np_acc * np_num_seqs)
-            avg_acc.extend(np_avg_acc * np_num_seqs)
-            total_num_seqs.extend(np_num_seqs)
-        print("[evaluation] classifier[-1] average acc: %f; average acc: %f" %
-              (np.sum(final_acc) / np.sum(total_num_seqs),
-               np.sum(avg_acc) / np.sum(total_num_seqs)))
-        self.cls_model.train()
-
-    def fit(self,
-            data_dir,
-            epoch,
-            batch_size=64,
-            use_cuda=True,
-            max_seq_len=512,
-            warmup_proportion=0.1,
-            use_data_parallel=False,
-            learning_rate=0.00005,
-            weight_decay=0.01,
-            lr_scheduler="linear_warmup_decay",
-            skip_steps=10,
-            save_steps=1000,
-            checkpoints="checkpoints"):
-        processor = self.processors[self.task_name](
-            data_dir=data_dir,
-            vocab_path=self.vocab_path,
-            max_seq_len=max_seq_len,
-            do_lower_case=self.do_lower_case,
-            in_tokens=False,
-            random_seed=5512)
-        shuffle_seed = 1 if self.trainer_count > 1 else None
-        train_data_generator = processor.data_generator(
-            batch_size=batch_size,
-            phase='train',
-            epoch=epoch,
-            dev_count=self.trainer_count,
-            shuffle=True,
-            shuffle_seed=shuffle_seed)
-        num_train_examples = processor.get_num_examples(phase='train')
-        max_train_steps = epoch * num_train_examples // batch_size // self.trainer_count
-        warmup_steps = int(max_train_steps * warmup_proportion)
-
-        print("Device count: %d" % self.dev_count)
-        print("Trainer count: %d" % self.trainer_count)
-        print("Num train examples: %d" % num_train_examples)
-        print("Max train steps: %d" % max_train_steps)
-        print("Num warmup steps: %d" % warmup_steps)
-
-        if use_data_parallel:
-            strategy = fluid.dygraph.parallel.prepare_context()
-
-        optimizer = Optimizer(
-            warmup_steps=warmup_steps,
-            num_train_steps=max_train_steps,
-            learning_rate=learning_rate,
-            model_cls=self.cls_model,
-            weight_decay=weight_decay,
-            scheduler=lr_scheduler,
-            loss_scaling=1.0,
-            parameter_list=self.cls_model.parameters())
-
-        if use_data_parallel:
-            self.cls_model = fluid.dygraph.parallel.DataParallel(
-                self.cls_model, strategy)
-            train_data_generator = fluid.contrib.reader.distributed_batch_reader(
-                train_data_generator)
-
-        steps = 0
-        time_begin = time.time()
-        for batch in train_data_generator():
-            data_ids = create_data(batch)
-            total_loss, logits, losses, accuracys, num_seqs = self.cls_model(
-                data_ids)
-            optimizer.optimization(
-                losses[-1],
-                use_data_parallel=use_data_parallel,
-                model=self.cls_model)
-            self.cls_model.clear_gradients()
-            if steps != 0 and steps % skip_steps == 0:
-                time_end = time.time()
-                used_time = time_end - time_begin
-                current_example, current_epoch = processor.get_train_progress()
-                localtime = time.asctime(time.localtime(time.time()))
-                print(
-                    "%s, epoch: %s, steps: %s, dy_graph loss: %f, acc: %f, speed: %f steps/s"
-                    % (localtime, current_epoch, steps, total_loss.numpy(),
-                       accuracys[-1].numpy(), skip_steps / used_time))
-                time_begin = time.time()
-            if steps != 0 and steps % save_steps == 0 and fluid.dygraph.parallel.Env(
-            ).local_rank == 0:
-                self.test(data_dir, batch_size=64, max_seq_len=512)
-                save_path = os.path.join(checkpoints, "steps" + "_" + str(steps))
-                fluid.save_dygraph(self.cls_model.state_dict(), save_path)
-                fluid.save_dygraph(optimizer.optimizer.state_dict(), save_path)
-                print("Save model parameters and optimizer status at %s" %
-                      save_path)
-            steps += 1
-        if fluid.dygraph.parallel.Env().local_rank == 0:
-            save_path = os.path.join(checkpoints, "final")
-            fluid.save_dygraph(self.cls_model.state_dict(), save_path)
-            fluid.save_dygraph(optimizer.optimizer.state_dict(), save_path)
-            print("Save model parameters and optimizer status at %s" %
-                  save_path)
+__all__ = ["AdaBERTClassifier"]
+
+
+class AdaBERTClassifier(Layer):
+    def __init__(self, num_labels, n_layer=12, emb_size=768):
+        super(AdaBERTClassifier, self).__init__()
+        self._n_layer = n_layer
+        self._num_labels = num_labels
+        self._emb_size = emb_size
+        self.teacher = BERTClassifier(num_labels)
+        self.student = BertModelLayer(
+            n_layer=self._n_layer, emb_size=self._emb_size)
+
+        self.cls_fc = list()
+        for i in range(self._n_layer):
+            fc = Linear(
+                input_dim=self._emb_size,
+                output_dim=self._num_labels,
+                param_attr=fluid.ParamAttr(
+                    name="s_cls_out_%d_w" % i,
+                    initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
+                bias_attr=fluid.ParamAttr(
+                    name="s_cls_out_%d_b" % i,
+                    initializer=fluid.initializer.Constant(0.)))
+            fc = self.add_sublayer("cls_fc_%d" % i, fc)
+            self.cls_fc.append(fc)
+
+    def forward(self, data_ids):
+        src_ids = data_ids[0]
+        position_ids = data_ids[1]
+        sentence_ids = data_ids[2]
+        return self.student(src_ids, position_ids, sentence_ids)
+
+    def arch_parameters(self):
+        return self.student.arch_parameters()
+
+    def genotype(self):
+        return self.arch_parameters()
+
+    def loss(self, data_ids, beta=0.5, gamma=0.5):
+        T = 1.0
+        src_ids = data_ids[0]
+        position_ids = data_ids[1]
+        sentence_ids = data_ids[2]
+        input_mask = data_ids[3]
+        labels = data_ids[4]
+
+        enc_outputs, next_sent_feats = self.student(src_ids, position_ids,
+                                                    sentence_ids)
+        self.teacher.eval()
+        total_loss, logits, losses, accuracys, num_seqs = self.teacher(
+            data_ids)
+
+        kd_losses = []
+        for t_logits, t_loss, s_sent_feat, fc in zip(
+                logits, losses, next_sent_feats, self.cls_fc):
+            s_sent_feat = fluid.layers.dropout(
+                x=s_sent_feat,
+                dropout_prob=0.1,
+                dropout_implementation="upscale_in_train")
+            s_logits = fc(s_sent_feat)
+            t_probs = fluid.layers.softmax(t_logits)
+            s_probs = fluid.layers.softmax(s_logits)
+            kd_loss = t_probs * fluid.layers.log(s_probs / T)
+            kd_loss = fluid.layers.reduce_sum(kd_loss, dim=1)
+            kd_loss = fluid.layers.reduce_mean(kd_loss, dim=0)
+            kd_loss = kd_loss / t_loss
+            kd_losses.append(kd_loss)
+        kd_loss = fluid.layers.sum(kd_losses)
+
+        ce_loss = fluid.layers.cross_entropy(s_probs, labels)
+        ce_loss = fluid.layers.mean(x=ce_loss)
+
+        e_loss = 1  # to be done
+        loss = (1 - gamma) * ce_loss + gamma * kd_loss + beta * e_loss
+        return loss
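For readers skimming the diff: the new AdaBERTClassifier.loss combines three terms, a cross-entropy loss on the student logits, a per-layer knowledge-distillation term scaled by the inverse of the corresponding teacher loss, and a placeholder efficiency term (e_loss = 1, marked "to be done" in the commit). The following is a minimal NumPy sketch of the same arithmetic, with toy shapes and random inputs that are purely illustrative, not code from the repository:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Toy setup (hypothetical): batch of 4, 3 labels, 2 teacher/student layer pairs.
rng = np.random.default_rng(0)
T, beta, gamma = 1.0, 0.5, 0.5
labels = rng.integers(0, 3, size=4)

kd_losses = []
for _ in range(2):  # one KD term per (teacher layer, student layer) pair
    t_logits = rng.normal(size=(4, 3))
    s_logits = rng.normal(size=(4, 3))
    t_loss = 1.3  # stands in for the teacher's scalar loss at this layer
    t_probs, s_probs = softmax(t_logits), softmax(s_logits)
    # Matches the diff: sum over classes, mean over the batch, scale by 1/t_loss.
    kd = (t_probs * np.log(s_probs / T)).sum(axis=1).mean()
    kd_losses.append(kd / t_loss)
kd_loss = np.sum(kd_losses)

# Cross-entropy on the student's (last-layer) probabilities, as in the diff.
ce_loss = -np.log(s_probs[np.arange(4), labels]).mean()
e_loss = 1.0  # efficiency term is still a placeholder in this commit
loss = (1 - gamma) * ce_loss + gamma * kd_loss + beta * e_loss
print(loss)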
paddleslim/nas/darts/search_space/conv_bert/model/bert.py

@@ -20,62 +20,38 @@ from __future__ import print_function
 import six
 import json
 import numpy as np
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, to_variable, Layer, guard
 from .transformer_encoder import EncoderLayer

-class BertConfig(object):
-    def __init__(self, config_path):
-        self._config_dict = self._parse(config_path)
-
-    def _parse(self, config_path):
-        try:
-            with open(config_path) as json_file:
-                config_dict = json.load(json_file)
-        except Exception:
-            raise IOError("Error in parsing bert model config file '%s'" %
-                          config_path)
-        else:
-            return config_dict
-
-    def __getitem__(self, key):
-        return self._config_dict[key]
-
-    def print_config(self):
-        for arg, value in sorted(six.iteritems(self._config_dict)):
-            print('%s: %s' % (arg, value))
-        print('------------------------------------------------')
-
-
 class BertModelLayer(Layer):
-    """
-    bert
-    """
-
-    def __init__(self, config, return_pooled_out=True, use_fp16=False):
+    def __init__(self,
+                 emb_size=768,
+                 n_layer=12,
+                 voc_size=30522,
+                 max_position_seq_len=512,
+                 sent_types=2,
+                 return_pooled_out=True,
+                 initializer_range=1.0,
+                 use_fp16=False):
         super(BertModelLayer, self).__init__()

-        self._emb_size = config['hidden_size']
-        self._n_layer = config['num_hidden_layers']
-        self._n_head = config['num_attention_heads']
-        self._voc_size = config['vocab_size']
-        self._max_position_seq_len = config['max_position_embeddings']
-        self._sent_types = config['type_vocab_size']
-        self._hidden_act = config['hidden_act']
-        self._prepostprocess_dropout = config['hidden_dropout_prob']
-        self._attention_dropout = config['attention_probs_dropout_prob']
+        self._emb_size = emb_size
+        self._n_layer = n_layer
+        self._voc_size = voc_size
+        self._max_position_seq_len = max_position_seq_len
+        self._sent_types = sent_types
         self.return_pooled_out = return_pooled_out

-        self._word_emb_name = "word_embedding"
-        self._pos_emb_name = "pos_embedding"
-        self._sent_emb_name = "sent_embedding"
+        self._word_emb_name = "s_word_embedding"
+        self._pos_emb_name = "s_pos_embedding"
+        self._sent_emb_name = "s_sent_embedding"
         self._dtype = "float16" if use_fp16 else "float32"

         self._param_initializer = fluid.initializer.TruncatedNormal(
-            scale=config['initializer_range'])
+            scale=initializer_range)

         self._src_emb = Embedding(
             size=[self._voc_size, self._emb_size],

@@ -99,14 +75,17 @@ class BertModelLayer(Layer):
             input_dim=self._emb_size,
             output_dim=self._emb_size,
             param_attr=fluid.ParamAttr(
-                name="pooled_fc.w_0", initializer=self._param_initializer),
-            bias_attr="pooled_fc.b_0",
+                name="s_pooled_fc.w_0", initializer=self._param_initializer),
+            bias_attr="s_pooled_fc.b_0",
             act="tanh")

         self._encoder = EncoderLayer(
             n_layer=self._n_layer, d_model=self._emb_size)

-    def forward(self, src_ids, position_ids, sentence_ids, input_mask):
+    def arch_parameters(self):
+        return [self._encoder.alphas]
+
+    def forward(self, src_ids, position_ids, sentence_ids):
         """
         forward
         """

@@ -131,113 +110,3 @@ class BertModelLayer(Layer):
             next_sent_feats.append(next_sent_feat)

         return enc_outputs, next_sent_feats
-
-
-class PretrainModelLayer(Layer):
-    """
-    pretrain model
-    """
-
-    def __init__(self,
-                 config,
-                 return_pooled_out=True,
-                 weight_sharing=True,
-                 use_fp16=False):
-        super(PretrainModelLayer, self).__init__()
-        self.config = config
-        self._voc_size = config['vocab_size']
-        self._emb_size = config['hidden_size']
-        self._hidden_act = config['hidden_act']
-        self._prepostprocess_dropout = config['hidden_dropout_prob']
-
-        self._word_emb_name = "word_embedding"
-        self._param_initializer = fluid.initializer.TruncatedNormal(
-            scale=config['initializer_range'])
-        self._weight_sharing = weight_sharing
-        self.use_fp16 = use_fp16
-        self._dtype = "float16" if use_fp16 else "float32"
-
-        self.bert_layer = BertModelLayer(
-            config=self.config, return_pooled_out=True, use_fp16=self.use_fp16)
-
-        self.pre_process_layer = PrePostProcessLayer(
-            "n", self._emb_size, self._prepostprocess_dropout, "pre_encoder")
-
-        self.pooled_fc = Linear(
-            input_dim=self._emb_size,
-            output_dim=self._emb_size,
-            param_attr=fluid.ParamAttr(
-                name="mask_lm_trans_fc.w_0",
-                initializer=self._param_initializer),
-            bias_attr="mask_lm_trans_fc.b_0",
-            act="tanh")
-
-        self.mask_lm_out_bias_attr = fluid.ParamAttr(
-            name="mask_lm_out_fc.b_0",
-            initializer=fluid.initializer.Constant(value=0.0))
-
-        if not self._weight_sharing:
-            self.out_fc = Linear(
-                input_dim=self._emb_size,
-                output_dim=self._voc_size,
-                param_attr=fluid.ParamAttr(
-                    name="mask_lm_out_fc.w_0",
-                    initializer=self._param_initializer),
-                bias_attr=self.mask_lm_out_bias_attr)
-        else:
-            self.fc_create_params = self.create_parameter(
-                shape=[self._voc_size],
-                dtype=self._dtype,
-                attr=self.mask_lm_out_bias_attr,
-                is_bias=True)
-
-        self.next_sent_fc = Linear(
-            input_dim=self._emb_size,
-            output_dim=2,
-            param_attr=fluid.ParamAttr(
-                name="next_sent_fc.w_0", initializer=self._param_initializer),
-            bias_attr="next_sent_fc.b_0")
-
-    def forward(self, src_ids, position_ids, sentence_ids, input_mask,
-                mask_label, mask_pos, labels):
-        """
-        forward
-        """
-        mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
-
-        enc_output, next_sent_feat = self.bert_layer(src_ids, position_ids,
-                                                     sentence_ids, input_mask)
-        reshaped_emb_out = fluid.layers.reshape(
-            x=enc_output, shape=[-1, self._emb_size])
-
-        mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
-        mask_trans_feat = self.pooled_fc(mask_feat)
-        mask_trans_feat = self.pre_process_layer(None, mask_trans_feat, "n",
-                                                 self._prepostprocess_dropout)
-
-        if self._weight_sharing:
-            fc_out = fluid.layers.matmul(
-                x=mask_trans_feat,
-                y=self.bert_layer._src_emb._w,
-                transpose_y=True)
-            fc_out += self.fc_create_params
-        else:
-            fc_out = self.out_fc(mask_trans_feat)
-
-        mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
-            logits=fc_out, label=mask_label)
-        mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
-
-        next_sent_fc_out = self.next_sent_fc(next_sent_feat)
-
-        next_sent_loss, next_sent_softmax = fluid.layers.softmax_with_cross_entropy(
-            logits=next_sent_fc_out, label=labels, return_softmax=True)
-
-        next_sent_acc = fluid.layers.accuracy(
-            input=next_sent_softmax, label=labels)
-
-        mean_next_sent_loss = fluid.layers.mean(next_sent_loss)
-
-        loss = mean_next_sent_loss + mean_mask_lm_loss
-        return next_sent_acc, mean_mask_lm_loss, loss
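The net effect of this file's changes: the student no longer reads its hyperparameters from a bert_config.json through BertConfig; it is constructed directly from keyword arguments, and its embedding and pooler parameters carry an "s_" prefix so they cannot collide with the teacher's. A hypothetical call site, for illustration only (the argument values are not from the repository):

import paddle.fluid as fluid
from paddleslim.nas.darts.search_space.conv_bert.model.bert import BertModelLayer

with fluid.dygraph.guard():
    # Before this commit the dimensions came from a JSON file:
    #   model = BertModelLayer(BertConfig("bert_config.json"))
    # After it they are plain keyword arguments:
    model = BertModelLayer(n_layer=8, emb_size=768, voc_size=30522)
    alphas = model.arch_parameters()  # DARTS architecture weights of the encoder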
paddleslim/nas/darts/search_space/conv_bert/model/cls.py

@@ -26,6 +26,7 @@ import paddle.fluid as fluid
 from paddle.fluid.dygraph import Linear, Layer
 from .bert import BertModelLayer
+from paddleslim.teachers.bert import BERTClassifier


 class ClsModelLayer(Layer):
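This import wires the distillation teacher into the search-space package. Based on how it is used in cls.py above, constructing the teacher is a one-liner; a sketch under the assumption of a 3-label task such as MNLI (checkpoint handling is not shown in this diff):

import paddle.fluid as fluid
from paddleslim.teachers.bert import BERTClassifier

with fluid.dygraph.guard():
    teacher = BERTClassifier(3)  # num_labels for the downstream task
    teacher.eval()  # AdaBERTClassifier.loss() switches the teacher to eval mode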
paddleslim/nas/darts/train_search.py

@@ -44,7 +44,7 @@ class DARTSearch(object):
                  batchsize=64,
                  num_imgs=50000,
                  arch_learning_rate=3e-4,
-                 unrolled='False',
+                 unrolled=False,
                  num_epochs=50,
                  epochs_no_archopt=0,
                  use_gpu=True,

@@ -73,32 +73,16 @@ class DARTSearch(object):
     def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
                         epoch):
         objs = AvgrageMeter()
-        top1 = AvgrageMeter()
-        top5 = AvgrageMeter()
         self.model.train()

         for step_id, (
                 train_data,
                 valid_data) in enumerate(zip(train_loader(), valid_loader())):
-            train_image, train_label = train_data
-            valid_image, valid_label = valid_data
-            train_image = to_variable(train_image)
-            train_label = to_variable(train_label)
-            train_label.stop_gradient = True
-            valid_image = to_variable(valid_image)
-            valid_label = to_variable(valid_label)
-            valid_label.stop_gradient = True
-            n = train_image.shape[0]
-
             if epoch >= self.epochs_no_archopt:
-                architect.step(train_image, train_label, valid_image,
-                               valid_label)
+                architect.step(train_data, valid_data)

-            logits = self.model(train_image)
-            prec1 = fluid.layers.accuracy(input=logits, label=train_label, k=1)
-            prec5 = fluid.layers.accuracy(input=logits, label=train_label, k=5)
-            loss = fluid.layers.reduce_mean(
-                fluid.layers.softmax_with_cross_entropy(logits, train_label))
+            loss = self.model.loss(train_data)

             if self.use_data_parallel:
                 loss = self.model.scale_loss(loss)

@@ -111,16 +95,12 @@ class DARTSearch(object):
             optimizer.minimize(loss, grad_clip)
             self.model.clear_gradients()

-            objs.update(loss.numpy(), n)
-            top1.update(prec1.numpy(), n)
-            top5.update(prec5.numpy(), n)
+            objs.update(loss.numpy(), self.batchsize)

             if step_id % self.log_freq == 0:
                 logger.info(
-                    "Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
-                    format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
-                        0]))
-        return top1.avg[0]
+                    "Train Epoch {}, Step {}, loss {:.6f}".format(
+                        epoch, step_id, objs.avg[0]))
+        return objs.avg[0]

     def valid_one_epoch(self, valid_loader, epoch):
         objs = AvgrageMeter()

@@ -128,7 +108,7 @@ class DARTSearch(object):
         top5 = AvgrageMeter()
         self.model.eval()

-        for step_id, (image, label) in enumerate(valid_loader):
+        for step_id, valid_data in enumerate(valid_loader):
             image = to_variable(image)
             label = to_variable(label)
             n = image.shape[0]

@@ -204,13 +184,11 @@ class DARTSearch(object):
             genotype = self.model.genotype()
             logger.info('genotype = %s', genotype)

-            train_top1 = self.train_one_epoch(train_loader, valid_loader,
-                                              architect, optimizer, epoch)
-            logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))
+            self.train_one_epoch(train_loader, valid_loader, architect,
+                                 optimizer, epoch)

             if epoch == self.num_epochs - 1:
-                valid_top1 = self.valid_one_epoch(valid_loader, epoch)
-                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
-                                                                valid_top1))
+                # valid_top1 = self.valid_one_epoch(valid_loader, epoch)
+                logger.info("Epoch {}, valid_acc {:.6f}".format(epoch, 1))

             if save_parameters:
                 fluid.save_dygraph(self.model.state_dict(), "./weights")
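After this commit, DARTSearch.train_one_epoch no longer assumes image/label batches: it hands the raw batch straight to architect.step and to the model's own loss method, which is what lets the same search loop drive the BERT search space above. A sketch of the contract this implies for any model plugged into the loop, written as a hypothetical base class rather than code from the repository:

import paddle.fluid as fluid

class SearchableModel(fluid.dygraph.Layer):
    """Hypothetical minimal interface implied by the new train loop."""

    def __init__(self):
        super(SearchableModel, self).__init__()

    def arch_parameters(self):
        # Architecture variables that `architect.step` updates.
        raise NotImplementedError

    def genotype(self):
        # Current discrete architecture, logged once per epoch.
        raise NotImplementedError

    def loss(self, batch):
        # Scalar training loss computed from one raw batch; the trainer
        # no longer unpacks or converts batches itself.
        raise NotImplementedError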