Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4e0447b0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4e0447b0
编写于
8月 20, 2020
作者:
D
dengyutao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support minddataset for tinybert
上级
b2cff284
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
40 addition
and
11 deletion
+40
-11
model_zoo/official/nlp/tinybert/run_general_distill.py
model_zoo/official/nlp/tinybert/run_general_distill.py
+10
-2
model_zoo/official/nlp/tinybert/run_task_distill.py
model_zoo/official/nlp/tinybert/run_task_distill.py
+11
-3
model_zoo/official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh
...official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh
+2
-1
model_zoo/official/nlp/tinybert/src/dataset.py
model_zoo/official/nlp/tinybert/src/dataset.py
+17
-5
未找到文件。
model_zoo/official/nlp/tinybert/run_general_distill.py
浏览文件 @
4e0447b0
...
...
@@ -28,7 +28,7 @@ from mindspore.train.parallel_utils import ParallelMode
from
mindspore.nn.optim
import
AdamWeightDecay
from
mindspore.nn.wrap.loss_scale
import
DynamicLossScaleUpdateCell
from
mindspore
import
log
as
logger
from
src.dataset
import
create_tinybert_dataset
from
src.dataset
import
create_tinybert_dataset
,
DataType
from
src.utils
import
LossCallBack
,
ModelSaveCkpt
,
BertLearningRate
from
src.gd_config
import
common_cfg
,
bert_teacher_net_cfg
,
bert_student_net_cfg
from
src.tinybert_for_gd_td
import
BertTrainWithLossScaleCell
,
BertNetworkWithLoss_gd
,
BertTrainCell
...
...
@@ -55,6 +55,7 @@ def run_general_distill():
parser
.
add_argument
(
"--load_teacher_ckpt_path"
,
type
=
str
,
default
=
""
,
help
=
"Load checkpoint file path"
)
parser
.
add_argument
(
"--data_dir"
,
type
=
str
,
default
=
""
,
help
=
"Data path, it is better to use absolute path"
)
parser
.
add_argument
(
"--schema_dir"
,
type
=
str
,
default
=
""
,
help
=
"Schema path, it is better to use absolute path"
)
parser
.
add_argument
(
"--dataset_type"
,
type
=
str
,
default
=
"tfrecord"
,
help
=
"dataset type, default is tfrecord"
)
args_opt
=
parser
.
parse_args
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
args_opt
.
device_target
,
device_id
=
args_opt
.
device_id
)
...
...
@@ -99,8 +100,15 @@ def run_general_distill():
student_config
=
bert_student_net_cfg
,
is_training
=
True
,
use_one_hot_embeddings
=
False
)
if
args_opt
.
dataset_type
==
"tfrecord"
:
dataset_type
=
DataType
.
TFRECORD
elif
arg_opt
.
dataset_type
==
"mindrecord"
:
dataset_type
=
DataType
.
MINDRECORD
else
:
raise
Exception
(
"dataset format is not supported yet"
)
dataset
=
create_tinybert_dataset
(
'gd'
,
bert_teacher_net_cfg
.
batch_size
,
device_num
,
rank
,
args_opt
.
do_shuffle
,
args_opt
.
data_dir
,
args_opt
.
schema_dir
)
args_opt
.
do_shuffle
,
args_opt
.
data_dir
,
args_opt
.
schema_dir
,
data_type
=
dataset_type
)
dataset_size
=
dataset
.
get_dataset_size
()
print
(
'dataset size: '
,
dataset_size
)
print
(
"dataset repeatcount: "
,
dataset
.
get_repeat_count
())
...
...
model_zoo/official/nlp/tinybert/run_task_distill.py
浏览文件 @
4e0447b0
...
...
@@ -27,7 +27,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from
mindspore.nn.wrap.loss_scale
import
DynamicLossScaleUpdateCell
from
mindspore.nn.optim
import
AdamWeightDecay
from
mindspore
import
log
as
logger
from
src.dataset
import
create_tinybert_dataset
from
src.dataset
import
create_tinybert_dataset
,
DataType
from
src.utils
import
LossCallBack
,
ModelSaveCkpt
,
EvalCallBack
,
BertLearningRate
from
src.assessment_method
import
Accuracy
from
src.td_config
import
phase1_cfg
,
phase2_cfg
,
td_teacher_net_cfg
,
td_student_net_cfg
...
...
@@ -68,7 +68,7 @@ def parse_args():
parser
.
add_argument
(
"--schema_dir"
,
type
=
str
,
default
=
""
,
help
=
"Schema path, it is better to use absolute path"
)
parser
.
add_argument
(
"--task_name"
,
type
=
str
,
default
=
""
,
choices
=
[
"SST-2"
,
"QNLI"
,
"MNLI"
],
help
=
"The name of the task to train."
)
parser
.
add_argument
(
"--dataset_type"
,
type
=
str
,
default
=
"tfrecord"
,
help
=
"dataset type, default is tfrecord"
)
args
=
parser
.
parse_args
()
return
args
...
...
@@ -119,9 +119,17 @@ def run_predistill():
rank
=
0
device_num
=
1
if
arg_opt
.
dataset_type
==
"tfrecord"
:
dataset_type
=
DataType
.
TFRECORD
elif
arg_opt
.
dataset_type
==
"mindrecord"
:
dataset_type
=
DataType
.
MINDRECORD
else
:
raise
Exception
(
"dataset format is not supported yet"
)
dataset
=
create_tinybert_dataset
(
'td'
,
td_teacher_net_cfg
.
batch_size
,
device_num
,
rank
,
args_opt
.
do_shuffle
,
args_opt
.
train_data_dir
,
args_opt
.
schema_dir
)
args_opt
.
train_data_dir
,
args_opt
.
schema_dir
,
data_tpye
=
dataset_type
)
dataset_size
=
dataset
.
get_dataset_size
()
print
(
'td1 dataset size: '
,
dataset_size
)
...
...
model_zoo/official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh
浏览文件 @
4e0447b0
...
...
@@ -39,4 +39,5 @@ python ${PROJECT_DIR}/../run_general_distill.py \
--save_ckpt_path
=
""
\
--load_teacher_ckpt_path
=
""
\
--data_dir
=
""
\
--schema_dir
=
""
>
log.txt 2>&1 &
--schema_dir
=
""
\
--dataset_type
=
"tfrecord"
>
log.txt 2>&1 &
model_zoo/official/nlp/tinybert/src/dataset.py
浏览文件 @
4e0447b0
...
...
@@ -16,26 +16,38 @@
"""create tinybert dataset"""
import
os
from
enum
import
Enum
import
mindspore.common.dtype
as
mstype
import
mindspore.dataset.engine.datasets
as
de
import
mindspore.dataset.transforms.c_transforms
as
C
class
DataType
(
Enum
):
"""Enumerate supported dataset format"""
TFRECORD
=
1
MINDRECORD
=
2
def
create_tinybert_dataset
(
task
=
'td'
,
batch_size
=
32
,
device_num
=
1
,
rank
=
0
,
do_shuffle
=
"true"
,
data_dir
=
None
,
schema_dir
=
None
):
do_shuffle
=
"true"
,
data_dir
=
None
,
schema_dir
=
None
,
data_type
=
DataType
.
TFRECORD
):
"""create tinybert dataset"""
files
=
os
.
listdir
(
data_dir
)
data_files
=
[]
for
file_name
in
files
:
if
"record"
in
file_name
:
if
"record"
in
file_name
and
"db"
not
in
file_name
:
data_files
.
append
(
os
.
path
.
join
(
data_dir
,
file_name
))
if
task
==
"td"
:
columns_list
=
[
"input_ids"
,
"input_mask"
,
"segment_ids"
,
"label_ids"
]
else
:
columns_list
=
[
"input_ids"
,
"input_mask"
,
"segment_ids"
]
ds
=
de
.
TFRecordDataset
(
data_files
,
schema_dir
,
columns_list
=
columns_list
,
shuffle
=
(
do_shuffle
==
"true"
),
num_shards
=
device_num
,
shard_id
=
rank
,
shard_equal_rows
=
True
)
if
data_type
==
DataType
.
MINDRECORD
:
ds
=
de
.
MindDataset
(
data_files
,
columns_list
=
columns_list
,
shuffle
=
(
do_shuffle
==
"true"
),
num_shards
=
device_num
,
shard_id
=
rank
)
else
:
ds
=
de
.
TFRecordDataset
(
data_files
,
schema_dir
,
columns_list
=
columns_list
,
shuffle
=
(
do_shuffle
==
"true"
),
num_shards
=
device_num
,
shard_id
=
rank
,
shard_equal_rows
=
True
)
type_cast_op
=
C
.
TypeCast
(
mstype
.
int32
)
ds
=
ds
.
map
(
input_columns
=
"segment_ids"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"input_mask"
,
operations
=
type_cast_op
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录