Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c0fd303e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c0fd303e
编写于
6月 01, 2020
作者:
Y
yoonlee666
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mlm eval script in bert example
上级
0920660e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
156 addition
and
0 deletion
+156
-0
model_zoo/bert/pretrain_eval.py
model_zoo/bert/pretrain_eval.py
+156
-0
未找到文件。
model_zoo/bert/pretrain_eval.py
0 → 100644
浏览文件 @
c0fd303e
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Bert evaluation script.
"""
import
os
from
src
import
BertModel
,
GetMaskedLMOutput
from
evaluation_config
import
cfg
,
bert_net_cfg
import
mindspore.common.dtype
as
mstype
from
mindspore
import
context
from
mindspore.common.tensor
import
Tensor
import
mindspore.dataset
as
de
import
mindspore.dataset.transforms.c_transforms
as
C
from
mindspore.train.model
import
Model
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
import
mindspore.nn
as
nn
from
mindspore.nn.metrics
import
Metric
from
mindspore.ops
import
operations
as
P
from
mindspore.common.parameter
import
Parameter
class
myMetric
(
Metric
):
'''
Self-defined Metric as a callback.
'''
def
__init__
(
self
):
super
(
myMetric
,
self
).
__init__
()
self
.
clear
()
def
clear
(
self
):
self
.
total_num
=
0
self
.
acc_num
=
0
def
update
(
self
,
*
inputs
):
total_num
=
self
.
_convert_data
(
inputs
[
0
])
acc_num
=
self
.
_convert_data
(
inputs
[
1
])
self
.
total_num
=
total_num
self
.
acc_num
=
acc_num
def
eval
(
self
):
return
self
.
acc_num
/
self
.
total_num
class
GetLogProbs
(
nn
.
Cell
):
'''
Get MaskedLM prediction scores
'''
def
__init__
(
self
,
config
):
super
(
GetLogProbs
,
self
).
__init__
()
self
.
bert
=
BertModel
(
config
,
False
)
self
.
cls1
=
GetMaskedLMOutput
(
config
)
def
construct
(
self
,
input_ids
,
input_mask
,
token_type_id
,
masked_pos
):
sequence_output
,
_
,
embedding_table
=
self
.
bert
(
input_ids
,
token_type_id
,
input_mask
)
prediction_scores
=
self
.
cls1
(
sequence_output
,
embedding_table
,
masked_pos
)
return
prediction_scores
class
BertPretrainEva
(
nn
.
Cell
):
'''
Evaluate MaskedLM prediction scores
'''
def
__init__
(
self
,
config
):
super
(
BertPretrainEva
,
self
).
__init__
()
self
.
bert
=
GetLogProbs
(
config
)
self
.
argmax
=
P
.
Argmax
(
axis
=-
1
,
output_type
=
mstype
.
int32
)
self
.
equal
=
P
.
Equal
()
self
.
mean
=
P
.
ReduceMean
()
self
.
sum
=
P
.
ReduceSum
()
self
.
total
=
Parameter
(
Tensor
([
0
],
mstype
.
float32
),
name
=
'total'
)
self
.
acc
=
Parameter
(
Tensor
([
0
],
mstype
.
float32
),
name
=
'acc'
)
self
.
reshape
=
P
.
Reshape
()
self
.
shape
=
P
.
Shape
()
self
.
cast
=
P
.
Cast
()
def
construct
(
self
,
input_ids
,
input_mask
,
token_type_id
,
masked_pos
,
masked_ids
,
nsp_label
,
masked_weights
):
bs
,
_
=
self
.
shape
(
input_ids
)
probs
=
self
.
bert
(
input_ids
,
input_mask
,
token_type_id
,
masked_pos
)
index
=
self
.
argmax
(
probs
)
index
=
self
.
reshape
(
index
,
(
bs
,
-
1
))
eval_acc
=
self
.
equal
(
index
,
masked_ids
)
eval_acc1
=
self
.
cast
(
eval_acc
,
mstype
.
float32
)
acc
=
self
.
mean
(
eval_acc1
)
P
.
Print
()(
acc
)
self
.
total
+=
self
.
shape
(
probs
)[
0
]
self
.
acc
+=
self
.
sum
(
eval_acc1
)
return
acc
,
self
.
total
,
self
.
acc
def
get_enwiki_512_dataset
(
batch_size
=
1
,
repeat_count
=
1
,
distribute_file
=
''
):
'''
Get enwiki seq_length=512 dataset
'''
ds
=
de
.
TFRecordDataset
([
cfg
.
data_file
],
cfg
.
schema_file
,
columns_list
=
[
"input_ids"
,
"input_mask"
,
"segment_ids"
,
"masked_lm_positions"
,
"masked_lm_ids"
,
"next_sentence_labels"
,
"masked_lm_weights"
])
type_cast_op
=
C
.
TypeCast
(
mstype
.
int32
)
ds
=
ds
.
map
(
input_columns
=
"segment_ids"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"input_mask"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"input_ids"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"masked_lm_ids"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"masked_lm_positions"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"next_sentence_labels"
,
operations
=
type_cast_op
)
ds
=
ds
.
repeat
(
repeat_count
)
# apply batch operations
ds
=
ds
.
batch
(
batch_size
,
drop_remainder
=
True
)
return
ds
def
bert_predict
():
'''
Predict function
'''
devid
=
int
(
os
.
getenv
(
'DEVICE_ID'
))
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
,
device_id
=
devid
)
dataset
=
get_enwiki_512_dataset
(
bert_net_cfg
.
batch_size
,
1
)
net_for_pretraining
=
BertPretrainEva
(
bert_net_cfg
)
net_for_pretraining
.
set_train
(
False
)
param_dict
=
load_checkpoint
(
cfg
.
finetune_ckpt
)
load_param_into_net
(
net_for_pretraining
,
param_dict
)
model
=
Model
(
net_for_pretraining
)
return
model
,
dataset
,
net_for_pretraining
def
MLM_eval
():
'''
Evaluate function
'''
_
,
dataset
,
net_for_pretraining
=
bert_predict
()
net
=
Model
(
net_for_pretraining
,
eval_network
=
net_for_pretraining
,
eval_indexes
=
[
0
,
1
,
2
],
metrics
=
{
myMetric
()})
res
=
net
.
eval
(
dataset
,
dataset_sink_mode
=
False
)
print
(
"=============================================================="
)
for
_
,
v
in
res
.
items
():
print
(
"Accuracy is: "
)
print
(
v
)
print
(
"=============================================================="
)
if
__name__
==
"__main__"
:
MLM_eval
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录