magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 1756d084
Author: hanhuifeng2020
Date: Jul 30, 2020
Parent: 0a1fac92

tinybert script suit for gpu

Showing 8 changed files with 272 additions and 40 deletions (+272 -40)
model_zoo/official/nlp/tinybert/README.md                               +2   -2
model_zoo/official/nlp/tinybert/run_general_distill.py                  +38  -15
model_zoo/official/nlp/tinybert/run_task_distill.py                     +80  -16
model_zoo/official/nlp/tinybert/scripts/run_distribute_gd_for_gpu.sh    +40  -0   (new file)
model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh            +1   -1
model_zoo/official/nlp/tinybert/src/dataset.py                          +0   -3
model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py               +107 -2
model_zoo/official/nlp/tinybert/src/utils.py                            +4   -1
model_zoo/official/nlp/tinybert/README.md

@@ -46,7 +46,7 @@ usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T
 options:
     --distribute               whether to run distributely: "true" | "false"
-    --device_target            target device to run, currently only support "Ascend"
+    --device_target            targeted device to run task: "Ascend" | "GPU"
     --epoch_size               epoch size: N, default is 1
     --device_id                device id: N, default is 0
     --enable_data_sink         enable data sink: "true" | "false", default is "true"

@@ -64,7 +64,7 @@ usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T
 options:
     --distribute               whether to run distributely: "true" | "false"
-    --device_target            target device to run, currently only support "Ascend"
+    --device_target            targeted device to run task: "Ascend" | "GPU"
     --epoch_size               epoch size: N, default is 1
     --device_id                device id: N, default is 0
     --device_num               device id to run task
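For reference (not part of this patch), a single-device GPU general-distillation run would presumably be launched along these lines; the flag names come from the README options above and from the GPU launch script added later in this commit, the paths are placeholders, and other flags keep their defaults:

    python run_general_distill.py --distribute="false" --device_target="GPU" --device_id=0 \
        --epoch_size=3 --data_dir=/path/data/ --schema_dir=/path/datasetSchema.json \
        --load_teacher_ckpt_path=/path/bert_base.ckpt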
model_zoo/official/nlp/tinybert/run_general_distill.py

@@ -20,16 +20,20 @@ import argparse
 import datetime
 import numpy
 import mindspore.communication.management as D
+import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.callback import TimeMonitor
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
+from mindspore import log as logger
 from src.dataset import create_tinybert_dataset
 from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
-from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd
+from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell

 def run_general_distill():
     """

@@ -53,7 +57,6 @@ def run_general_distill():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
-    context.set_context(variable_memory_max_size="30GB")

@@ -61,13 +64,17 @@ def run_general_distill():
     save_ckpt_dir = os.path.join(args_opt.save_ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
-    if not os.path.exists(save_ckpt_dir):
-        os.makedirs(save_ckpt_dir)
     if args_opt.distribute == "true":
-        D.init('hccl')
-        device_num = args_opt.device_num
-        rank = args_opt.device_id % device_num
+        if args_opt.device_target == 'Ascend':
+            D.init('hccl')
+            device_num = args_opt.device_num
+            rank = args_opt.device_id % device_num
+        else:
+            D.init('nccl')
+            device_num = D.get_group_size()
+            rank = D.get_rank()
         save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, device_num=device_num)

@@ -75,6 +82,21 @@ def run_general_distill():
         rank = 0
         device_num = 1
+    if not os.path.exists(save_ckpt_dir):
+        os.makedirs(save_ckpt_dir)
+
+    enable_loss_scale = True
+    if args_opt.device_target == "GPU":
+        if bert_teacher_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            bert_teacher_net_cfg.compute_type = mstype.float32
+        if bert_student_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            bert_student_net_cfg.compute_type = mstype.float32
+        # Both the forward and backward of the network are calculated using fp32,
+        # and the loss scale is not necessary
+        enable_loss_scale = False
     netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                          teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                          student_config=bert_student_net_cfg,

@@ -82,11 +104,11 @@ def run_general_distill():
     dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
                                       args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
     dataset_size = dataset.get_dataset_size()
     print('dataset size: ', dataset_size)
     print("dataset repeatcount: ", dataset.get_repeat_count())
     if args_opt.enable_data_sink == "true":
-        repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.epoch_size

@@ -110,12 +132,13 @@ def run_general_distill():
                                    args_opt.save_ckpt_step,
                                    args_opt.max_ckpt_num,
                                    save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
-                                             scale_factor=common_cfg.scale_factor,
-                                             scale_window=common_cfg.scale_window)
-    netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
+                                                 scale_factor=common_cfg.scale_factor,
+                                                 scale_window=common_cfg.scale_window)
+        netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == "true"),
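A note on the enable_loss_scale switch introduced above: dynamic loss scaling exists to keep small fp16 gradients from underflowing, so once both teacher and student are forced to fp32 on GPU it buys nothing, and the patch wraps the network in BertTrainCell instead of BertTrainWithLossScaleCell. The following framework-free sketch illustrates the kind of update rule a dynamic loss scaler applies; the parameter names mirror common_cfg above, but the values and the class itself are illustrative, not MindSpore code:

    class DynamicLossScaler:
        """Illustrative dynamic loss-scale bookkeeping (not the MindSpore implementation).

        The loss is multiplied by `scale` before backprop so small fp16 gradients
        do not underflow; gradients are divided by `scale` before the optimizer step.
        """
        def __init__(self, loss_scale_value=2 ** 16, scale_factor=2, scale_window=1000):
            self.scale = loss_scale_value
            self.scale_factor = scale_factor
            self.scale_window = scale_window
            self.good_steps = 0

        def update(self, overflow):
            if overflow:
                # Gradients overflowed: shrink the scale and restart the counter.
                self.scale = max(self.scale / self.scale_factor, 1)
                self.good_steps = 0
            else:
                # After `scale_window` consecutive clean steps, grow the scale again.
                self.good_steps += 1
                if self.good_steps >= self.scale_window:
                    self.scale *= self.scale_factor
                    self.good_steps = 0

    scaler = DynamicLossScaler(loss_scale_value=2 ** 16, scale_factor=2, scale_window=1000)
    scaler.update(overflow=True)   # e.g. the first step overflowed
    print(scaler.scale)            # 2 ** 15

Each step the scaler is fed an overflow flag; shrinking on overflow and growing only after scale_window clean steps keeps the scale near the largest value that does not overflow.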
model_zoo/official/nlp/tinybert/run_task_distill.py

@@ -18,6 +18,7 @@
 import os
 import re
 import argparse
+import mindspore.common.dtype as mstype
 from mindspore import Tensor
 from mindspore import context
 from mindspore.train.model import Model

@@ -25,11 +26,12 @@ from mindspore.train.callback import TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.nn.optim import AdamWeightDecay
+from mindspore import log as logger
 from src.dataset import create_tinybert_dataset
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
 from src.assessment_method import Accuracy
 from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
-from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td
+from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
 from src.tinybert_model import BertModelCLS

 _cur_dir = os.getcwd()

@@ -45,14 +47,14 @@ def parse_args():
     parse args
     """
     parser = argparse.ArgumentParser(description='tinybert task distill')
-    parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.")
+    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.")
     parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.")
     parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
                         help="Epoch size for td phase 1, default is 10.")
     parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
-    parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.")
     parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
     parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
     parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.")

@@ -64,11 +66,43 @@ def parse_args():
     parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
+                        help="The name of the task to train.")
     args = parser.parse_args()
     return args

 args_opt = parse_args()
+
+DEFAULT_NUM_LABELS = 2
+DEFAULT_SEQ_LENGTH = 128
+task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
+               "QNLI": {"num_labels": 2, "seq_length": 128},
+               "MNLI": {"num_labels": 3, "seq_length": 128}}
+
+
+class Task:
+    """
+    Encapsulation class of get the task parameter.
+    """
+    def __init__(self, task_name):
+        self.task_name = task_name
+
+    @property
+    def num_labels(self):
+        if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
+            return task_params[self.task_name]["num_labels"]
+        return DEFAULT_NUM_LABELS
+
+    @property
+    def seq_length(self):
+        if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
+            return task_params[self.task_name]["seq_length"]
+        return DEFAULT_SEQ_LENGTH
+
+task = Task(args_opt.task_name)

 def run_predistill():
     """
     run predistill

@@ -81,7 +115,7 @@ def run_predistill():
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                          is_training=True, task_type='classification',
-                                         num_labels=args_opt.num_labels, is_predistill=True)
+                                         num_labels=task.num_labels, is_predistill=True)
     rank = 0
     device_num = 1

@@ -91,8 +125,9 @@ def run_predistill():
     dataset_size = dataset.get_dataset_size()
     print('td1 dataset size: ', dataset_size)
     print('td1 dataset repeatcount: ', dataset.get_repeat_count())
     if args_opt.enable_data_sink == 'true':
-        repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.td_phase1_epoch_size

@@ -117,10 +152,14 @@ def run_predistill():
                                      args_opt.save_ckpt_step,
                                      args_opt.max_ckpt_num,
                                      td_phase1_save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
-                                             scale_factor=cfg.scale_factor,
-                                             scale_window=cfg.scale_window)
-    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
+                                                 scale_factor=cfg.scale_factor,
+                                                 scale_window=cfg.scale_window)
+        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == 'true'),

@@ -139,7 +178,7 @@ def run_task_distill(ckpt_file):
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                          is_training=True, task_type='classification',
-                                         num_labels=args_opt.num_labels, is_predistill=False)
+                                         num_labels=task.num_labels, is_predistill=False)
     rank = 0
     device_num = 1

@@ -149,6 +188,7 @@ def run_task_distill(ckpt_file):
     dataset_size = train_dataset.get_dataset_size()
     print('td2 train dataset size: ', dataset_size)
+    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
     if args_opt.enable_data_sink == 'true':
         repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps

@@ -175,6 +215,7 @@ def run_task_distill(ckpt_file):
     eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
                                            args_opt.eval_data_dir, args_opt.schema_dir)
+    print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

     if args_opt.do_eval.lower() == "true":
         callback = [TimeMonitor(time_monitor_steps), LossCallBack(),

@@ -185,11 +226,14 @@ def run_task_distill(ckpt_file):
                                        args_opt.save_ckpt_step,
                                        args_opt.max_ckpt_num,
                                        td_phase2_save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
-                                             scale_factor=cfg.scale_factor,
-                                             scale_window=cfg.scale_window)
-    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
+                                                 scale_factor=cfg.scale_factor,
+                                                 scale_window=cfg.scale_window)
+        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, train_dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == 'true'),

@@ -203,7 +247,7 @@ def do_eval_standalone():
     if ckpt_file == '':
         raise ValueError("Student ckpt file should not be None")
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student")
+    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     param_dict = load_checkpoint(ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():

@@ -213,10 +257,13 @@ def do_eval_standalone():
     load_param_into_net(eval_model, new_param_dict)
     eval_model.set_train(False)

-    eval_dataset = create_tinybert_dataset('td', batch_size=1,
+    eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size,
                                            device_num=1, rank=0, do_shuffle="false",
                                            data_dir=args_opt.eval_data_dir,
                                            schema_dir=args_opt.schema_dir)
+    print('eval dataset size: ', eval_dataset.get_dataset_size())
+    print('eval dataset batch size: ', eval_dataset.get_batch_size())
+
     callback = Accuracy()
     columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     for data in eval_dataset.create_dict_iterator():

@@ -231,9 +278,26 @@ def do_eval_standalone():
     print("============== acc is {}".format(acc))
     print("======================================")

 if __name__ == '__main__':
     if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
         raise ValueError("do_train or do eval must have one be true, please confirm your config")
+
+    enable_loss_scale = True
+    if args_opt.device_target == "GPU":
+        if td_teacher_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            td_teacher_net_cfg.compute_type = mstype.float32
+        if td_student_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            td_student_net_cfg.compute_type = mstype.float32
+        # Both the forward and backward of the network are calculated using fp32,
+        # and the loss scale is not necessary
+        enable_loss_scale = False
+
+    td_teacher_net_cfg.seq_length = task.seq_length
+    td_student_net_cfg.seq_length = task.seq_length
+
     if args_opt.do_train == "true":
         # run predistill
         run_predistill()
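The new Task helper derives num_labels and seq_length from --task_name, replacing the hand-passed --num_labels. A self-contained illustration of the lookup (it restates the same task_params table from the diff rather than importing the module, so it can be run on its own):

    # Minimal standalone illustration of the Task lookup added in run_task_distill.py.
    DEFAULT_NUM_LABELS = 2
    DEFAULT_SEQ_LENGTH = 128
    task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                   "QNLI": {"num_labels": 2, "seq_length": 128},
                   "MNLI": {"num_labels": 3, "seq_length": 128}}

    def task_num_labels(task_name):
        return task_params.get(task_name, {}).get("num_labels", DEFAULT_NUM_LABELS)

    def task_seq_length(task_name):
        return task_params.get(task_name, {}).get("seq_length", DEFAULT_SEQ_LENGTH)

    print(task_num_labels("MNLI"), task_seq_length("MNLI"))   # 3 128
    print(task_num_labels("SST-2"), task_seq_length("SST-2")) # 2 64
    print(task_num_labels(""), task_seq_length(""))           # 2 128 -> the DEFAULT_* fallbacks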
model_zoo/official/nlp/tinybert/scripts/run_distribute_gd_for_gpu.sh (new file, mode 100644)

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_gd_for_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR TEACHER_CKPT_PATH"
echo "for example: bash run_distribute_gd_for_gpu.sh 8 3 /path/data/ /path/datasetSchema.json /path/bert_base.ckpt"
echo "It is better to use absolute path."
echo "=============================================================================================================="

RANK_SIZE=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
TEACHER_CKPT_PATH=$5

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)

mpirun --allow-run-as-root -n $RANK_SIZE \
    python ${PROJECT_DIR}/../run_general_distill.py \
        --distribute="true" \
        --device_target="GPU" \
        --epoch_size=$EPOCH_SIZE \
        --save_ckpt_path="" \
        --data_dir=$DATA_DIR \
        --schema_dir=$SCHEMA_DIR \
        --load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &
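Per the script's own usage banner, an 8-GPU run over 3 epochs would be started as:

    bash run_distribute_gd_for_gpu.sh 8 3 /path/data/ /path/datasetSchema.json /path/bert_base.ckpt

mpirun spawns RANK_SIZE copies of run_general_distill.py; each process then obtains its rank and group size through D.init('nccl') / D.get_rank() / D.get_group_size(), as shown in the run_general_distill.py hunk above, and output goes to log.txt because the backgrounded job redirects stdout and stderr.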
model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh

@@ -32,7 +32,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --do_eval="true" \
     --td_phase1_epoch_size=10 \
     --td_phase2_epoch_size=3 \
-    --num_labels=2 \
+    --task_name="" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
     --data_sink_steps=100 \
model_zoo/official/nlp/tinybert/src/dataset.py

@@ -19,7 +19,6 @@ import os
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
-from mindspore import log as logger

 def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
                             do_shuffle="true", data_dir=None, schema_dir=None):

@@ -45,7 +44,5 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
     ds = ds.map(input_columns="label_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)
-    logger.info("data size: {}".format(ds.get_dataset_size()))
-    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds
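create_tinybert_dataset already takes device_num and rank, which is what lets the distributed GPU path give each process its own shard of the data. A hedged sketch of how the distributed caller wires this up (it mirrors the run_general_distill.py hunk above; the data paths are placeholders and the snippet assumes a MindSpore GPU environment launched under mpirun):

    import mindspore.communication.management as D
    from src.dataset import create_tinybert_dataset

    D.init('nccl')                   # one process per GPU when launched with mpirun
    device_num = D.get_group_size()  # total number of processes
    rank = D.get_rank()              # this process's shard id

    dataset = create_tinybert_dataset('gd', batch_size=32,
                                      device_num=device_num, rank=rank, do_shuffle="true",
                                      data_dir="/path/data/",                 # placeholder
                                      schema_dir="/path/datasetSchema.json")  # placeholder
    print('per-rank dataset size: ', dataset.get_dataset_size())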
model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py

@@ -292,6 +292,60 @@ class BertTrainWithLossScaleCell(nn.Cell):
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)

+
+class BertTrainCell(nn.Cell):
+    """
+    Encapsulation class of bert network training.
+
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        sens (Number): The adjust parameter. Default: 1.0.
+    """
+    def __init__(self, network, optimizer, sens=1.0):
+        super(BertTrainCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.sens = sens
+        self.grad = C.GradOperation('grad',
+                                    get_by_list=True,
+                                    sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            mean = context.get_auto_parallel_context("mirror_mean")
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
+        self.cast = P.Cast()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 self.cast(F.tuple_to_array((self.sens,)),
+                                                           mstype.float32))
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        succ = self.optimizer(grads)
+        return F.depend(loss, succ)
+
+
 class BertNetworkWithLoss_td(nn.Cell):
     """
     Provide bert pre-training loss through network.

@@ -411,12 +465,12 @@ class BertNetworkWithLoss_td(nn.Cell):
             total_loss += cls_loss
         return self.cast(total_loss, mstype.float32)

-class BertEvaluationCell(nn.Cell):
+class BertEvaluationWithLossScaleCell(nn.Cell):
     """
     Especifically defined for finetuning where only four inputs tensor are needed.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertEvaluationCell, self).__init__(auto_prefix=False)
+        super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer

@@ -496,3 +550,54 @@ class BertEvaluationCell(nn.Cell):
         succ = self.optimizer(grads)
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)
+
+
+class BertEvaluationCell(nn.Cell):
+    """
+    Especifically defined for finetuning where only four inputs tensor are needed.
+    """
+    def __init__(self, network, optimizer, sens=1.0):
+        super(BertEvaluationCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.sens = sens
+        self.grad = C.GradOperation('grad',
+                                    get_by_list=True,
+                                    sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            mean = context.get_auto_parallel_context("mirror_mean")
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.cast = P.Cast()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id,
+                  label_ids):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id,
+                            label_ids)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 label_ids,
+                                                 self.cast(F.tuple_to_array((self.sens,)),
+                                                           mstype.float32))
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        succ = self.optimizer(grads)
+        return F.depend(loss, succ)
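Both new cells keep the existing per-parameter gradient clipping, applied through self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads); they only drop the loss-scale bookkeeping. For readers unfamiliar with the pattern, here is a NumPy sketch of the two usual clipping modes; it is illustrative only, and the 0/1 mode mapping below is an assumption about the sketch, not a statement about the repository's clip_grad:

    import numpy as np

    def clip_grad_sketch(clip_type, clip_value, grad):
        """Illustrative per-tensor gradient clipping (not the MindSpore clip_grad op).

        clip_type 0: clip each element into [-clip_value, clip_value].
        clip_type 1: rescale the whole tensor if its L2 norm exceeds clip_value.
        """
        if clip_type == 0:
            return np.clip(grad, -clip_value, clip_value)
        norm = np.linalg.norm(grad)
        if norm > clip_value:
            return grad * (clip_value / norm)
        return grad

    g = np.array([3.0, -4.0])           # L2 norm is 5
    print(clip_grad_sketch(1, 1.0, g))  # rescaled to unit norm: [ 0.6 -0.8]
    print(clip_grad_sketch(0, 1.0, g))  # element-wise clip:     [ 1.  -1. ]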
model_zoo/official/nlp/tinybert/src/utils.py

@@ -110,7 +110,10 @@ class EvalCallBack(Callback):
         if acc > self.global_acc:
             self.global_acc = acc
             print("The best acc is {}".format(acc))
-            _exec_save_checkpoint(self.network, "eval_model.ckpt")
+            eval_model_ckpt_file = "eval_model.ckpt"
+            if os.path.exists(eval_model_ckpt_file):
+                os.remove(eval_model_ckpt_file)
+            _exec_save_checkpoint(self.network, eval_model_ckpt_file)

 class BertLearningRate(LearningRateSchedule):
     """