Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6b994518
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6b994518
编写于
6月 04, 2020
作者:
Y
yoonlee666
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs in bert example script
上级
c0fd303e
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
19 addition
and
14 deletion
+19
-14
model_zoo/bert/pretrain_eval.py
model_zoo/bert/pretrain_eval.py
+11
-9
model_zoo/bert/run_pretrain.py
model_zoo/bert/run_pretrain.py
+6
-3
model_zoo/bert/src/cluener_evaluation.py
model_zoo/bert/src/cluener_evaluation.py
+2
-2
未找到文件。
model_zoo/bert/pretrain_eval.py
浏览文件 @
6b994518
...
...
@@ -19,7 +19,7 @@ Bert evaluation script.
import
os
from
src
import
BertModel
,
GetMaskedLMOutput
from
evaluation_config
import
cfg
,
bert_net_cfg
from
src.
evaluation_config
import
cfg
,
bert_net_cfg
import
mindspore.common.dtype
as
mstype
from
mindspore
import
context
from
mindspore.common.tensor
import
Tensor
...
...
@@ -87,17 +87,18 @@ class BertPretrainEva(nn.Cell):
self
.
cast
=
P
.
Cast
()
def
construct
(
self
,
input_ids
,
input_mask
,
token_type_id
,
masked_pos
,
masked_ids
,
nsp_label
,
masked_weights
):
def
construct
(
self
,
input_ids
,
input_mask
,
token_type_id
,
masked_pos
,
masked_ids
,
masked_weights
,
nsp_label
):
bs
,
_
=
self
.
shape
(
input_ids
)
probs
=
self
.
bert
(
input_ids
,
input_mask
,
token_type_id
,
masked_pos
)
index
=
self
.
argmax
(
probs
)
index
=
self
.
reshape
(
index
,
(
bs
,
-
1
))
eval_acc
=
self
.
equal
(
index
,
masked_ids
)
eval_acc1
=
self
.
cast
(
eval_acc
,
mstype
.
float32
)
acc
=
self
.
mean
(
eval_acc1
)
P
.
Print
()(
acc
)
self
.
total
+=
self
.
shape
(
probs
)[
0
]
self
.
acc
+=
self
.
sum
(
eval_acc1
)
real_acc
=
eval_acc1
*
masked_weights
acc
=
self
.
sum
(
real_acc
)
total
=
self
.
sum
(
masked_weights
)
self
.
total
+=
total
self
.
acc
+=
acc
return
acc
,
self
.
total
,
self
.
acc
...
...
@@ -107,8 +108,8 @@ def get_enwiki_512_dataset(batch_size=1, repeat_count=1, distribute_file=''):
'''
ds
=
de
.
TFRecordDataset
([
cfg
.
data_file
],
cfg
.
schema_file
,
columns_list
=
[
"input_ids"
,
"input_mask"
,
"segment_ids"
,
"masked_lm_positions"
,
"masked_lm_ids"
,
"
next_sentence_label
s"
,
"
masked_lm_weight
s"
])
"
masked_lm_weight
s"
,
"
next_sentence_label
s"
])
type_cast_op
=
C
.
TypeCast
(
mstype
.
int32
)
ds
=
ds
.
map
(
input_columns
=
"segment_ids"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"input_mask"
,
operations
=
type_cast_op
)
...
...
@@ -143,7 +144,8 @@ def MLM_eval():
Evaluate function
'''
_
,
dataset
,
net_for_pretraining
=
bert_predict
()
net
=
Model
(
net_for_pretraining
,
eval_network
=
net_for_pretraining
,
eval_indexes
=
[
0
,
1
,
2
],
metrics
=
{
myMetric
()})
net
=
Model
(
net_for_pretraining
,
eval_network
=
net_for_pretraining
,
eval_indexes
=
[
0
,
1
,
2
],
metrics
=
{
'name'
:
myMetric
()})
res
=
net
.
eval
(
dataset
,
dataset_sink_mode
=
False
)
print
(
"=============================================================="
)
for
_
,
v
in
res
.
items
():
...
...
model_zoo/bert/run_pretrain.py
浏览文件 @
6b994518
...
...
@@ -66,6 +66,8 @@ def run_pretrain():
parser
.
add_argument
(
"--checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Checkpoint file path"
)
parser
.
add_argument
(
"--save_checkpoint_steps"
,
type
=
int
,
default
=
1000
,
help
=
"Save checkpoint steps, "
"default is 1000."
)
parser
.
add_argument
(
"--train_steps"
,
type
=
int
,
default
=-
1
,
help
=
"Training Steps, default is -1, "
"meaning run all steps according to epoch number."
)
parser
.
add_argument
(
"--save_checkpoint_num"
,
type
=
int
,
default
=
1
,
help
=
"Save checkpoint numbers, default is 1."
)
parser
.
add_argument
(
"--data_dir"
,
type
=
str
,
default
=
""
,
help
=
"Data path, it is better to use absolute path"
)
parser
.
add_argument
(
"--schema_dir"
,
type
=
str
,
default
=
""
,
help
=
"Schema path, it is better to use absolute path"
)
...
...
@@ -93,11 +95,12 @@ def run_pretrain():
ds
,
new_repeat_count
=
create_bert_dataset
(
args_opt
.
epoch_size
,
device_num
,
rank
,
args_opt
.
do_shuffle
,
args_opt
.
enable_data_sink
,
args_opt
.
data_sink_steps
,
args_opt
.
data_dir
,
args_opt
.
schema_dir
)
if
args_opt
.
train_steps
>
0
:
new_repeat_count
=
min
(
new_repeat_count
,
args_opt
.
train_steps
//
args_opt
.
data_sink_steps
)
netwithloss
=
BertNetworkWithLoss
(
bert_net_cfg
,
True
)
if
cfg
.
optimizer
==
'Lamb'
:
optimizer
=
Lamb
(
netwithloss
.
trainable_params
(),
decay_steps
=
ds
.
get_dataset_size
()
*
ds
.
get_repeat_count
()
,
optimizer
=
Lamb
(
netwithloss
.
trainable_params
(),
decay_steps
=
ds
.
get_dataset_size
()
*
new_repeat_count
,
start_learning_rate
=
cfg
.
Lamb
.
start_learning_rate
,
end_learning_rate
=
cfg
.
Lamb
.
end_learning_rate
,
power
=
cfg
.
Lamb
.
power
,
warmup_steps
=
cfg
.
Lamb
.
warmup_steps
,
weight_decay
=
cfg
.
Lamb
.
weight_decay
,
eps
=
cfg
.
Lamb
.
eps
)
...
...
@@ -106,7 +109,7 @@ def run_pretrain():
momentum
=
cfg
.
Momentum
.
momentum
)
elif
cfg
.
optimizer
==
'AdamWeightDecayDynamicLR'
:
optimizer
=
AdamWeightDecayDynamicLR
(
netwithloss
.
trainable_params
(),
decay_steps
=
ds
.
get_dataset_size
()
*
ds
.
get_repeat_count
()
,
decay_steps
=
ds
.
get_dataset_size
()
*
new_repeat_count
,
learning_rate
=
cfg
.
AdamWeightDecayDynamicLR
.
learning_rate
,
end_learning_rate
=
cfg
.
AdamWeightDecayDynamicLR
.
end_learning_rate
,
power
=
cfg
.
AdamWeightDecayDynamicLR
.
power
,
...
...
model_zoo/bert/src/cluener_evaluation.py
浏览文件 @
6b994518
...
...
@@ -19,8 +19,8 @@ import json
import
numpy
as
np
import
mindspore.common.dtype
as
mstype
from
mindspore.common.tensor
import
Tensor
import
tokenization
from
sample_process
import
label_generation
,
process_one_example_p
from
.
import
tokenization
from
.
sample_process
import
label_generation
,
process_one_example_p
from
.evaluation_config
import
cfg
from
.CRF
import
postprocess
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录