magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 6e21618f
Authored on Aug 03, 2020 by mindspore-ci-bot; committed by Gitee on Aug 03, 2020

!3798 add shuffle switch for finetune dataset

Merge pull request !3798 from yoonlee666/master-finetune

Parents: 2b565627, 9bdece71
Showing 8 changed files with 43 additions and 20 deletions (+43 -20)
model_zoo/official/nlp/bert/run_classifier.py           +12 -7
model_zoo/official/nlp/bert/run_ner.py                  +8  -2
model_zoo/official/nlp/bert/run_squad.py                +8  -2
model_zoo/official/nlp/bert/scripts/run_classifier.sh   +2  -0
model_zoo/official/nlp/bert/scripts/run_ner.sh          +2  -0
model_zoo/official/nlp/bert/scripts/run_squad.sh        +2  -0
model_zoo/official/nlp/bert/src/assessment_method.py    +0  -1
model_zoo/official/nlp/bert/src/dataset.py              +9  -8
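The change is the same across all three finetune tasks: each run_*.py script gains a pair of string-valued shuffle switches, and each create_*_dataset helper in src/dataset.py gains a do_shuffle parameter that is forwarded to the TFRecord reader. The sketch below is a minimal, self-contained illustration of that flow, not code from the patch; build_dataset is a hypothetical stand-in for the create_*_dataset functions.

    import argparse

    def build_dataset(data_file, do_shuffle=True):
        # Hypothetical stand-in for create_ner_dataset / create_classification_dataset /
        # create_squad_dataset: the boolean ends up as the `shuffle` argument of the
        # underlying TFRecordDataset, exactly as in the patch.
        print("reading %s with shuffle=%s" % (data_file, do_shuffle))

    parser = argparse.ArgumentParser(description="shuffle switch demo")
    # The same string-valued flags the patch adds to run_classifier.py, run_ner.py and run_squad.py:
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle, default is true")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle, default is false")
    args = parser.parse_args([])  # use the defaults for the demo

    # The patch converts the strings to booleans at the call sites:
    build_dataset("train.tf_record", do_shuffle=(args.train_data_shuffle.lower() == "true"))
    build_dataset("eval.tf_record", do_shuffle=(args.eval_data_shuffle.lower() == "true"))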
model_zoo/official/nlp/bert/run_classifier.py
@@ -133,14 +133,17 @@ def run_classifier():
"""run classifier task"""
"""run classifier task"""
parser
=
argparse
.
ArgumentParser
(
description
=
"run classifier"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"run classifier"
)
parser
.
add_argument
(
"--device_target"
,
type
=
str
,
default
=
"Ascend"
,
help
=
"Device type, default is Ascend"
)
parser
.
add_argument
(
"--device_target"
,
type
=
str
,
default
=
"Ascend"
,
help
=
"Device type, default is Ascend"
)
parser
.
add_argument
(
"--assessment_method"
,
type
=
str
,
default
=
"accuracy"
,
help
=
"assessment_method include: "
parser
.
add_argument
(
"--assessment_method"
,
type
=
str
,
default
=
"accuracy"
,
"[MCC, Spearman_correlation, "
help
=
"assessment_method including [MCC, Spearman_correlation, Accuracy], default is accuracy"
)
"Accuracy], default is accuracy"
)
parser
.
add_argument
(
"--do_train"
,
type
=
str
,
default
=
"false"
,
help
=
"Enable train, default is false"
)
parser
.
add_argument
(
"--do_train"
,
type
=
str
,
default
=
"false"
,
help
=
"Eable train, default is false"
)
parser
.
add_argument
(
"--do_eval"
,
type
=
str
,
default
=
"false"
,
help
=
"Enable eval, default is false"
)
parser
.
add_argument
(
"--do_eval"
,
type
=
str
,
default
=
"false"
,
help
=
"Eable eval, default is false"
)
parser
.
add_argument
(
"--device_id"
,
type
=
int
,
default
=
0
,
help
=
"Device id, default is 0."
)
parser
.
add_argument
(
"--device_id"
,
type
=
int
,
default
=
0
,
help
=
"Device id, default is 0."
)
parser
.
add_argument
(
"--epoch_num"
,
type
=
int
,
default
=
"1"
,
help
=
"Epoch number, default is 1."
)
parser
.
add_argument
(
"--epoch_num"
,
type
=
int
,
default
=
"1"
,
help
=
"Epoch number, default is 1."
)
parser
.
add_argument
(
"--num_class"
,
type
=
int
,
default
=
"2"
,
help
=
"The number of class, default is 2."
)
parser
.
add_argument
(
"--num_class"
,
type
=
int
,
default
=
"2"
,
help
=
"The number of class, default is 2."
)
parser
.
add_argument
(
"--train_data_shuffle"
,
type
=
str
,
default
=
"true"
,
help
=
"Enable train data shuffle, default is true"
)
parser
.
add_argument
(
"--eval_data_shuffle"
,
type
=
str
,
default
=
"false"
,
help
=
"Enable eval data shuffle, default is false"
)
parser
.
add_argument
(
"--save_finetune_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Save checkpoint path"
)
parser
.
add_argument
(
"--save_finetune_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Save checkpoint path"
)
parser
.
add_argument
(
"--load_pretrain_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Load checkpoint file path"
)
parser
.
add_argument
(
"--load_pretrain_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Load checkpoint file path"
)
parser
.
add_argument
(
"--load_finetune_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Load checkpoint file path"
)
parser
.
add_argument
(
"--load_finetune_checkpoint_path"
,
type
=
str
,
default
=
""
,
help
=
"Load checkpoint file path"
)
@@ -182,7 +185,8 @@ def run_classifier():
         ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                            assessment_method=assessment_method,
                                            data_file_path=args_opt.train_data_file_path,
-                                           schema_file_path=args_opt.schema_file_path)
+                                           schema_file_path=args_opt.schema_file_path,
+                                           do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)

     if args_opt.do_eval.lower() == "true":
@@ -197,7 +201,8 @@ def run_classifier():
         ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                            assessment_method=assessment_method,
                                            data_file_path=args_opt.eval_data_file_path,
-                                           schema_file_path=args_opt.schema_file_path)
+                                           schema_file_path=args_opt.schema_file_path,
+                                           do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
         do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path)

 if __name__ == "__main__":
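A note on the flag style used here: the switches are declared with type=str and compared against "true", rather than declared with type=bool. With argparse, type=bool would apply bool() to the raw argument string, and every non-empty string, including "false", is truthy in Python, so a boolean-typed flag could never be turned off from the command line. The lowercase string comparison used in the patch sidesteps that. A quick demonstration of the pitfall (my example, not part of the commit):

    # argparse would call bool() on the raw argument string; any non-empty
    # string is truthy, so --flag=false would still mean True:
    assert bool("false") is True
    assert bool("") is False

    # The idiom used in the patch gives the intended result:
    assert ("false".lower() == "true") is False
    assert ("TRUE".lower() == "true") is True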
model_zoo/official/nlp/bert/run_ner.py
@@ -150,6 +150,10 @@ def run_ner():
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
     parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.")
+    parser.add_argument("--train_data_shuffle", type=str, default="true",
+                        help="Enable train data shuffle, default is true")
+    parser.add_argument("--eval_data_shuffle", type=str, default="false",
+                        help="Enable eval data shuffle, default is false")
     parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark")
     parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark")
     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
@@ -208,7 +212,8 @@ def run_ner():
     if args_opt.do_train.lower() == "true":
         ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                 assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
-                                schema_file_path=args_opt.schema_file_path)
+                                schema_file_path=args_opt.schema_file_path,
+                                do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)

     if args_opt.do_eval.lower() == "true":
@@ -222,7 +227,8 @@ def run_ner():
     if args_opt.do_eval.lower() == "true":
         ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                 assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
-                                schema_file_path=args_opt.schema_file_path)
+                                schema_file_path=args_opt.schema_file_path,
+                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
         do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
                 load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index)
model_zoo/official/nlp/bert/run_squad.py
@@ -140,6 +140,10 @@ def run_squad():
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
     parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.")
+    parser.add_argument("--train_data_shuffle", type=str, default="true",
+                        help="Enable train data shuffle, default is true")
+    parser.add_argument("--eval_data_shuffle", type=str, default="false",
+                        help="Enable eval data shuffle, default is false")
     parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path")
     parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json")
     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
@@ -186,7 +190,8 @@ def run_squad():
     if args_opt.do_train.lower() == "true":
         ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                   data_file_path=args_opt.train_data_file_path,
-                                  schema_file_path=args_opt.schema_file_path)
+                                  schema_file_path=args_opt.schema_file_path,
+                                  do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
     if args_opt.do_eval.lower() == "true":
         if save_finetune_checkpoint_path == "":
@@ -199,7 +204,8 @@ def run_squad():
     if args_opt.do_eval.lower() == "true":
         ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                   data_file_path=args_opt.eval_data_file_path,
-                                  schema_file_path=args_opt.schema_file_path, is_training=False)
+                                  schema_file_path=args_opt.schema_file_path, is_training=False,
+                                  do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
         do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path,
                 load_finetune_checkpoint_path, bert_net_cfg.seq_length)
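Note that the SQuAD evaluation branch keeps is_training=False and pairs it with the new switch, whose default --eval_data_shuffle="false" leaves examples in file order; that keeps emitted predictions aligned with their unique_ids when scoring against eval.json. A hedged usage sketch of the updated call (the import paths follow the bert example's layout, and the file paths are placeholders):

    from src.finetune_eval_config import bert_net_cfg  # config module used by run_squad.py
    from src.dataset import create_squad_dataset

    # Evaluation pass: deterministic order, no training-only columns.
    ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                              data_file_path="/path/to/eval.tf_record",  # placeholder
                              schema_file_path="",
                              is_training=False,
                              do_shuffle=False)  # keep file order for eval.json scoring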
model_zoo/official/nlp/bert/scripts/run_classifier.sh
@@ -34,6 +34,8 @@ python ${PROJECT_DIR}/../run_classifier.py \
     --device_id=0 \
     --epoch_num=1 \
     --num_class=2 \
+    --train_data_shuffle="true" \
+    --eval_data_shuffle="false" \
     --save_finetune_checkpoint_path="" \
     --load_pretrain_checkpoint_path="" \
     --load_finetune_checkpoint_path="" \
model_zoo/official/nlp/bert/scripts/run_ner.sh
@@ -35,6 +35,8 @@ python ${PROJECT_DIR}/../run_ner.py \
     --device_id=0 \
     --epoch_num=1 \
     --num_class=2 \
+    --train_data_shuffle="true" \
+    --eval_data_shuffle="false" \
     --vocab_file_path="" \
     --label2id_file_path="" \
     --save_finetune_checkpoint_path="" \
model_zoo/official/nlp/bert/scripts/run_squad.sh
@@ -33,6 +33,8 @@ python ${PROJECT_DIR}/../run_squad.py \
     --device_id=0 \
     --epoch_num=1 \
     --num_class=2 \
+    --train_data_shuffle="true" \
+    --eval_data_shuffle="false" \
     --vocab_file_path="" \
     --eval_json_path="" \
     --save_finetune_checkpoint_path="" \
model_zoo/official/nlp/bert/src/assessment_method.py
@@ -34,7 +34,6 @@ class Accuracy():
         logit_id = np.argmax(logits, axis=-1)
         self.acc_num += np.sum(labels == logit_id)
         self.total_num += len(labels)
-        print("=========================accuracy is ", self.acc_num / self.total_num)

 class F1():
     '''
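The deleted line printed the running accuracy on every update() call; after this commit the caller reads acc_num / total_num once evaluation is finished. A self-contained rendering of the same accumulator pattern (simplified from the class in this file; the fake logits are mine):

    import numpy as np

    class Accuracy:
        """Accumulate correct/total counts across batches; report once at the end."""
        def __init__(self):
            self.acc_num = 0
            self.total_num = 0

        def update(self, logits, labels):
            logit_id = np.argmax(logits, axis=-1)       # predicted class per example
            self.acc_num += np.sum(labels == logit_id)  # count correct predictions
            self.total_num += len(labels)

    metric = Accuracy()
    metric.update(np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 0]))  # both correct
    metric.update(np.array([[0.7, 0.3]]), np.array([1]))                 # one wrong
    print("accuracy is", metric.acc_num / metric.total_num)              # 2/3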
model_zoo/official/nlp/bert/src/dataset.py
@@ -53,11 +53,11 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
-                       data_file_path=None, schema_file_path=None):
+                       data_file_path=None, schema_file_path=None, do_shuffle=True):
     """create finetune or evaluation dataset"""
     type_cast_op = C.TypeCast(mstype.int32)
     ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
-                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
+                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], shuffle=do_shuffle)
     if assessment_method == "Spearman_correlation":
         type_cast_op_float = C.TypeCast(mstype.float32)
         ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
@@ -76,11 +76,11 @@ def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy
 def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
-                                  data_file_path=None, schema_file_path=None):
+                                  data_file_path=None, schema_file_path=None, do_shuffle=True):
     """create finetune or evaluation dataset"""
     type_cast_op = C.TypeCast(mstype.int32)
     ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
-                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
+                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], shuffle=do_shuffle)
     if assessment_method == "Spearman_correlation":
         type_cast_op_float = C.TypeCast(mstype.float32)
         ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
@@ -98,14 +98,15 @@ def create_classification_dataset(batch_size=1, repeat_count=1, assessment_metho
     return ds

 def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None,
-                         is_training=True):
+                         is_training=True, do_shuffle=True):
     """create finetune or evaluation dataset"""
     type_cast_op = C.TypeCast(mstype.int32)
     if is_training:
         ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
                                 columns_list=["input_ids", "input_mask", "segment_ids",
                                               "start_positions", "end_positions",
-                                              "unique_ids", "is_impossible"])
+                                              "unique_ids", "is_impossible"],
+                                shuffle=do_shuffle)
         ds = ds.map(input_columns="start_positions", operations=type_cast_op)
         ds = ds.map(input_columns="end_positions", operations=type_cast_op)
     else:
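Across all three helpers the pattern is identical: do_shuffle defaults to True so existing callers keep their behavior, and the flag is passed straight through as TFRecordDataset's shuffle argument. A condensed sketch of the shape these functions now share, written against the 2020-era MindSpore dataset API that this file itself uses (the function name and defaults here are illustrative, not part of the patch):

    import mindspore.common.dtype as mstype
    import mindspore.dataset as de
    import mindspore.dataset.transforms.c_transforms as C

    def create_finetune_dataset(data_file_path, schema_file_path=None,
                                batch_size=1, repeat_count=1, do_shuffle=True):
        """Illustrative helper mirroring create_ner_dataset / create_classification_dataset."""
        ds = de.TFRecordDataset([data_file_path], schema_file_path,
                                columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"],
                                shuffle=do_shuffle)  # the new switch lands here
        type_cast_op = C.TypeCast(mstype.int32)
        for col in ["input_ids", "input_mask", "segment_ids", "label_ids"]:
            ds = ds.map(input_columns=col, operations=type_cast_op)
        ds = ds.repeat(repeat_count)
        ds = ds.batch(batch_size, drop_remainder=True)
        return ds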