magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the fork source)
Commit 6fdf3809
Authored Jul 21, 2020 by chenhaozhe

fix bert scripts to adapt the new concept of repeatcount in minddata

Parent: ad651f38
Showing 3 changed files with 27 additions and 29 deletions (+27 −29):

model_zoo/official/nlp/bert/run_pretrain.py   +23 −19
model_zoo/official/nlp/bert/src/dataset.py     +3 −10
model_zoo/official/nlp/bert/src/utils.py       +1  −0
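The core of the change is visible in the first two files: create_bert_dataset no longer repeats and resizes the dataset itself, so run_pretrain must derive the repeat count from the epoch count and the sink size. A plain-Python sketch of the new arithmetic follows; all values are illustrative assumptions, not taken from the BERT config.

# Sink-mode bookkeeping after this commit (hypothetical numbers standing in
# for args_opt and the real dataset).
epoch_size = 40        # args_opt.epoch_size
dataset_size = 1000    # ds.get_dataset_size(): batches in one pass over the data
data_sink_steps = 100  # args_opt.data_sink_steps: steps drained per sink cycle

# model.train() now counts sink cycles rather than epochs, so the repeat count
# is the total number of steps divided by the steps per cycle:
new_repeat_count = epoch_size * dataset_size // data_sink_steps  # 400
train_steps = epoch_size * dataset_size                          # 40000
assert new_repeat_count * data_sink_steps == train_steps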
model_zoo/official/nlp/bert/run_pretrain.py
@@ -64,7 +64,6 @@ def run_pretrain():
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
-    context.set_context(variable_memory_max_size="30GB")
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
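For orientation, the surrounding setup pins graph mode and the execution device; a minimal standalone equivalent is sketched below (the device target and id are hypothetical placeholders, not values from the script).

from mindspore import context

# Graph-mode execution on a single device; target and id are placeholders.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)
context.set_context(reserve_class_name_in_scope=False)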
@@ -99,47 +98,49 @@ def run_pretrain():
             logger.warning('Gpu only support fp32 temporarily, run with fp32.')
             bert_net_cfg.compute_type = mstype.float32
 
-    ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink,
-                             args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir)
-    new_repeat_count = args_opt.epoch_size
+    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
+
+    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
     if args_opt.train_steps > 0:
-        new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
-    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
+        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+    else:
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
+
     if cfg.optimizer == 'Lamb':
         lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                        end_learning_rate=cfg.Lamb.end_learning_rate,
                                        warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.Lamb.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params}]
+                        {'params': other_params},
+                        {'order_params': params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
+        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                              momentum=cfg.Momentum.momentum)
     elif cfg.optimizer == 'AdamWeightDecay':
         lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                        end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                        warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.AdamWeightDecay.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0}]
+                        {'params': other_params, 'weight_decay': 0.0},
+                        {'order_params': params}]
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                          format(cfg.optimizer))
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
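Both optimizer branches also gain an 'order_params' group. Below is a minimal sketch of that grouped-parameter pattern on a tiny stand-in network; the nn.Dense net and the bias-based decay rule are illustrative assumptions, not the BERT network or its decay filter.

import mindspore.nn as nn
from mindspore.nn.optim import Lamb

net = nn.Dense(8, 2)                 # stand-in for BertNetworkWithLoss
params = net.trainable_params()
decay_params = [p for p in params if 'bias' not in p.name]  # weights: decayed
other_params = [p for p in params if 'bias' in p.name]      # biases: no decay
group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params, 'weight_decay': 0.0},
                {'order_params': params}]  # keep the network's parameter order
optimizer = Lamb(group_params, learning_rate=1e-4, eps=1e-6)

The 'order_params' entry presumably lets the grouped optimizer apply updates in the same parameter order as the network, which is what the added lines above rely on.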
@@ -148,19 +149,22 @@ def run_pretrain():
     if args_opt.load_checkpoint_path:
         param_dict = load_checkpoint(args_opt.load_checkpoint_path)
-        load_param_into_net(netwithloss, param_dict)
+        load_param_into_net(net_with_loss, param_dict)
 
     if args_opt.enable_lossscale == "true":
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
-                                                         scale_update_cell=update_cell)
+        net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
+                                                           scale_update_cell=update_cell)
     else:
-        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
-    model = Model(netwithgrads)
-    model.train(new_repeat_count, ds, callbacks=callback,
-                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
+        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
+    model = Model(net_with_grads)
+    model.train(new_repeat_count, ds, callbacks=callback,
+                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
+                sink_size=args_opt.data_sink_steps)
 
 if __name__ == '__main__':
     numpy.random.seed(0)
     run_pretrain()
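The training call is the other half of the repeat-count change: model.train now receives an explicit sink_size, and its first argument counts sink-sized chunks rather than raw epochs. A short sketch of the before/after shape, with hypothetical values:

# Before: the dataset module resized ds so one pass == data_sink_steps batches,
# and the call was model.train(new_repeat_count, ds, dataset_sink_mode=True).
# After: ds keeps its natural size and the sink length is passed explicitly:
#     model.train(new_repeat_count, ds, dataset_sink_mode=True,
#                 sink_size=args_opt.data_sink_steps)
# The invariant either way: total optimizer steps = repeat count * sink steps.
data_sink_steps, new_repeat_count = 100, 400  # hypothetical
assert new_repeat_count * data_sink_steps == 40_000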
model_zoo/official/nlp/bert/src/dataset.py
@@ -23,11 +23,9 @@ from mindspore import log as logger
 from .config import bert_net_cfg
 
-def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
-                        data_sink_steps=1, data_dir=None, schema_dir=None):
+def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
     """create train dataset"""
-    # apply repeat operations
-    repeat_count = epoch_size
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:
@@ -40,11 +38,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                               num_shards=device_num, shard_id=rank, shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
     print('origin dataset size: ', ori_dataset_size)
-    new_size = ori_dataset_size
-    if enable_data_sink == "true":
-        new_size = data_sink_steps * bert_net_cfg.batch_size
-    ds.set_dataset_size(new_size)
-    new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
@@ -55,8 +48,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
     logger.info("data size: {}".format(ds.get_dataset_size()))
-    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
-    return ds, new_repeat_count
+    logger.info("repeat count: {}".format(ds.get_repeat_count()))
+    return ds
 
 
 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
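For contrast, the effect of the deleted resizing block can be written out as plain arithmetic (all values hypothetical):

# Old scheme, as deleted above: shrink the dataset so one pass equals one sink
# cycle, then repeat it enough times to cover the requested training.
epoch_size = 1                 # the call site passed 1 here by this point
ori_dataset_size = 320_000     # rows reported by ds.get_dataset_size()
batch_size, data_sink_steps = 32, 100

new_size = data_sink_steps * batch_size                            # 3200 rows/cycle
new_repeat_count = int(epoch_size * ori_dataset_size // new_size)  # 100 cycles

After this commit none of that lives in the dataset module: create_bert_dataset just shards, casts, and batches, and the trainer's sink_size owns the cycle length.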
model_zoo/official/nlp/bert/src/utils.py
@@ -18,6 +18,7 @@ Functional Cells used in Bert finetune and evaluation.
 """
 import os
+import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor