Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
1373e294
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1373e294
编写于
4月 13, 2020
作者:
X
xyzhou-puck
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add leveldb reader for bert
上级
b2f94aa8
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
647 addition
and
8 deletion
+647
-8
examples/bert_leveldb/bert.yaml
examples/bert_leveldb/bert.yaml
+27
-0
examples/bert_leveldb/bert_classifier.py
examples/bert_leveldb/bert_classifier.py
+116
-0
examples/bert_leveldb/cls.py
examples/bert_leveldb/cls.py
+73
-0
examples/bert_leveldb/nohup.out
examples/bert_leveldb/nohup.out
+312
-0
examples/bert_leveldb/run_classifier_single_gpu.sh
examples/bert_leveldb/run_classifier_single_gpu.sh
+29
-0
hapi/text/bert/dataloader.py
hapi/text/bert/dataloader.py
+90
-8
未找到文件。
examples/bert_leveldb/bert.yaml
0 → 100644
浏览文件 @
1373e294
bert_config_path
:
"
./config/bert_config.json"
init_checkpoint
:
None
init_pretraining_params
:
None
checkpoints
:
"
./saved_model"
epoch
:
3
learning_rate
:
0.0001
lr_scheduler
:
"
linear_warmup_decay"
weight_decay
:
0.01
warmup_proportion
:
0.1
save_steps
:
100000
validation_steps
:
100000
loss_scaling
:
1.0
skip_steps
:
100
data_dir
:
None
vocab_path
:
None
max_seq_len
:
512
batch_size
:
32
in_tokens
:
False
do_lower_case
:
True
random_seed
:
5512
use_cuda
:
False
shuffle
:
True
do_train
:
True
do_test
:
True
use_data_parallel
:
False
verbose
:
False
examples/bert_leveldb/bert_classifier.py
0 → 100755
浏览文件 @
1373e294
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT fine-tuning in Paddle Dygraph Mode."""
import
paddle.fluid
as
fluid
from
hapi.metrics
import
Accuracy
from
hapi.configure
import
Config
from
hapi.model
import
set_device
,
Model
,
SoftmaxWithCrossEntropy
,
Input
from
cls
import
ClsModelLayer
import
hapi.text.tokenizer.tokenization
as
tokenization
from
hapi.text.bert
import
Optimizer
,
BertConfig
,
BertDataLoader
,
BertInputExample
def
train
():
config
=
Config
(
yaml_file
=
"./bert.yaml"
)
config
.
build
()
config
.
Print
()
device
=
set_device
(
"gpu"
if
config
.
use_cuda
else
"cpu"
)
fluid
.
enable_dygraph
(
device
)
bert_config
=
BertConfig
(
config
.
bert_config_path
)
bert_config
.
print_config
()
trainer_count
=
fluid
.
dygraph
.
parallel
.
Env
().
nranks
tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
config
.
vocab_path
,
do_lower_case
=
config
.
do_lower_case
)
def
mnli_line_processor
(
line_id
,
line
):
if
line_id
==
"0"
:
return
None
uid
=
tokenization
.
convert_to_unicode
(
line
[
0
])
text_a
=
tokenization
.
convert_to_unicode
(
line
[
8
])
text_b
=
tokenization
.
convert_to_unicode
(
line
[
9
])
label
=
tokenization
.
convert_to_unicode
(
line
[
-
1
])
if
label
not
in
[
"contradiction"
,
"entailment"
,
"neutral"
]:
label
=
"contradiction"
return
BertInputExample
(
uid
=
uid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
)
bert_dataloader
=
BertDataLoader
(
"./data/glue_data/MNLI/train.tsv"
,
tokenizer
,
[
"contradiction"
,
"entailment"
,
"neutral"
],
max_seq_length
=
64
,
batch_size
=
32
,
line_processor
=
mnli_line_processor
,
mode
=
"leveldb"
)
num_train_examples
=
len
(
bert_dataloader
.
dataset
)
max_train_steps
=
config
.
epoch
*
num_train_examples
//
config
.
batch_size
//
trainer_count
warmup_steps
=
int
(
max_train_steps
*
config
.
warmup_proportion
)
print
(
"Trainer count: %d"
%
trainer_count
)
print
(
"Num train examples: %d"
%
num_train_examples
)
print
(
"Max train steps: %d"
%
max_train_steps
)
print
(
"Num warmup steps: %d"
%
warmup_steps
)
inputs
=
[
Input
(
[
None
,
None
],
'int64'
,
name
=
'src_ids'
),
Input
(
[
None
,
None
],
'int64'
,
name
=
'pos_ids'
),
Input
(
[
None
,
None
],
'int64'
,
name
=
'sent_ids'
),
Input
(
[
None
,
None
],
'float32'
,
name
=
'input_mask'
)
]
labels
=
[
Input
([
None
,
1
],
'int64'
,
name
=
'label'
)]
cls_model
=
ClsModelLayer
(
config
,
bert_config
,
len
([
"contradiction"
,
"entailment"
,
"neutral"
]),
is_training
=
True
,
return_pooled_out
=
True
)
optimizer
=
Optimizer
(
warmup_steps
=
warmup_steps
,
num_train_steps
=
max_train_steps
,
learning_rate
=
config
.
learning_rate
,
model_cls
=
cls_model
,
weight_decay
=
config
.
weight_decay
,
scheduler
=
config
.
lr_scheduler
,
loss_scaling
=
config
.
loss_scaling
,
parameter_list
=
cls_model
.
parameters
())
cls_model
.
prepare
(
optimizer
,
SoftmaxWithCrossEntropy
(),
Accuracy
(
topk
=
(
1
,
2
)),
inputs
,
labels
,
device
=
device
)
cls_model
.
bert_layer
.
init_parameters
(
config
.
init_pretraining_params
,
verbose
=
config
.
verbose
)
cls_model
.
fit
(
train_data
=
bert_dataloader
.
dataloader
,
epochs
=
config
.
epoch
)
return
cls_model
if
__name__
==
'__main__'
:
cls_model
=
train
()
examples/bert_leveldb/cls.py
0 → 100644
浏览文件 @
1373e294
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"dygraph transformer layers"
import
six
import
json
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Linear
,
Layer
from
hapi.text.bert
import
BertEncoder
from
hapi.model
import
Model
class
ClsModelLayer
(
Model
):
"""
classify model
"""
def
__init__
(
self
,
args
,
config
,
num_labels
,
is_training
=
True
,
return_pooled_out
=
True
,
use_fp16
=
False
):
super
(
ClsModelLayer
,
self
).
__init__
()
self
.
config
=
config
self
.
is_training
=
is_training
self
.
use_fp16
=
use_fp16
self
.
loss_scaling
=
args
.
loss_scaling
self
.
bert_layer
=
BertEncoder
(
config
=
self
.
config
,
return_pooled_out
=
True
,
use_fp16
=
self
.
use_fp16
)
self
.
cls_fc
=
Linear
(
input_dim
=
self
.
config
[
"hidden_size"
],
output_dim
=
num_labels
,
param_attr
=
fluid
.
ParamAttr
(
name
=
"cls_out_w"
,
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
scale
=
0.02
)),
bias_attr
=
fluid
.
ParamAttr
(
name
=
"cls_out_b"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.
)))
def
forward
(
self
,
src_ids
,
position_ids
,
sentence_ids
,
input_mask
):
"""
forward
"""
enc_output
,
next_sent_feat
=
self
.
bert_layer
(
src_ids
,
position_ids
,
sentence_ids
,
input_mask
)
cls_feats
=
fluid
.
layers
.
dropout
(
x
=
next_sent_feat
,
dropout_prob
=
0.1
,
dropout_implementation
=
"upscale_in_train"
)
logits
=
self
.
cls_fc
(
cls_feats
)
return
logits
examples/bert_leveldb/nohup.out
0 → 100644
浏览文件 @
1373e294
grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
2020-04-13 13:08:30,568-WARNING: use_shared_memory can only be used in multi-process mode(num_workers > 0), set use_shared_memory as False
W0413 13:08:31.584532 119379 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0413 13:08:31.589192 119379 device_context.cc:245] device: 0, cuDNN Version: 7.5.
----------------------------------------------------------------------
bert_config_path: ./data/pretrained_models/uncased_L-12_H-768_A-12//bert_config.json
init_checkpoint: None
init_pretraining_params: ./data/pretrained_models/uncased_L-12_H-768_A-12//dygraph_params/
checkpoints: ./data/saved_model/mnli_models
epoch: 3
learning_rate: 5e-05
lr_scheduler: linear_warmup_decay
weight_decay: 0.01
warmup_proportion: 0.1
save_steps: 1000
validation_steps: 100
loss_scaling: 1.0
skip_steps: 10
data_dir: ./data/glue_data/MNLI/
vocab_path: ./data/pretrained_models/uncased_L-12_H-768_A-12//vocab.txt
max_seq_len: 128
batch_size: 64
in_tokens: False
do_lower_case: True
random_seed: 5512
use_cuda: True
shuffle: True
do_train: True
do_test: True
use_data_parallel: False
verbose: False
----------------------------------------------------------------------
attention_probs_dropout_prob: 0.1
hidden_act: gelu
hidden_dropout_prob: 0.1
hidden_size: 768
initializer_range: 0.02
intermediate_size: 3072
max_position_embeddings: 512
num_attention_heads: 12
num_hidden_layers: 12
type_vocab_size: 2
vocab_size: 30522
------------------------------------------------
Trainer count: 1
Num train examples: 392703
Max train steps: 18407
Num warmup steps: 1840
Epoch 1/3
step 10/12272 - loss: 1.1000 - acc_top1: 0.3531 - acc_top2: 0.6813 - 1s/step
step 20/12272 - loss: 1.1878 - acc_top1: 0.3578 - acc_top2: 0.6875 - 1s/step
step 30/12272 - loss: 1.0812 - acc_top1: 0.3708 - acc_top2: 0.6948 - 1s/step
step 40/12272 - loss: 1.1244 - acc_top1: 0.3773 - acc_top2: 0.6992 - 1s/step
step 50/12272 - loss: 1.1202 - acc_top1: 0.3756 - acc_top2: 0.7006 - 1s/step
step 60/12272 - loss: 1.1291 - acc_top1: 0.3703 - acc_top2: 0.6990 - 1s/step
step 70/12272 - loss: 1.0991 - acc_top1: 0.3634 - acc_top2: 0.6946 - 1s/step
step 80/12272 - loss: 1.0988 - acc_top1: 0.3602 - acc_top2: 0.6914 - 1s/step
step 90/12272 - loss: 1.0718 - acc_top1: 0.3646 - acc_top2: 0.6889 - 1s/step
step 100/12272 - loss: 1.0949 - acc_top1: 0.3638 - acc_top2: 0.6878 - 1s/step
step 110/12272 - loss: 1.1120 - acc_top1: 0.3608 - acc_top2: 0.6895 - 1s/step
step 120/12272 - loss: 1.1105 - acc_top1: 0.3622 - acc_top2: 0.6922 - 1s/step
step 130/12272 - loss: 1.0958 - acc_top1: 0.3623 - acc_top2: 0.6940 - 1s/step
step 140/12272 - loss: 1.0995 - acc_top1: 0.3636 - acc_top2: 0.6926 - 1s/step
step 150/12272 - loss: 1.1272 - acc_top1: 0.3671 - acc_top2: 0.6950 - 1s/step
step 160/12272 - loss: 1.0850 - acc_top1: 0.3697 - acc_top2: 0.6975 - 1s/step
step 170/12272 - loss: 1.0607 - acc_top1: 0.3691 - acc_top2: 0.6991 - 1s/step
step 180/12272 - loss: 1.0623 - acc_top1: 0.3707 - acc_top2: 0.6991 - 1s/step
step 190/12272 - loss: 1.1092 - acc_top1: 0.3697 - acc_top2: 0.6997 - 1s/step
step 200/12272 - loss: 1.1046 - acc_top1: 0.3713 - acc_top2: 0.7030 - 1s/step
step 210/12272 - loss: 1.0945 - acc_top1: 0.3720 - acc_top2: 0.7043 - 1s/step
step 220/12272 - loss: 1.0935 - acc_top1: 0.3719 - acc_top2: 0.7051 - 1s/step
step 230/12272 - loss: 1.1567 - acc_top1: 0.3742 - acc_top2: 0.7048 - 1s/step
step 240/12272 - loss: 1.0745 - acc_top1: 0.3766 - acc_top2: 0.7081 - 1s/step
step 250/12272 - loss: 1.0664 - acc_top1: 0.3756 - acc_top2: 0.7090 - 1s/step
step 260/12272 - loss: 1.0770 - acc_top1: 0.3751 - acc_top2: 0.7085 - 1s/step
step 270/12272 - loss: 1.1008 - acc_top1: 0.3730 - acc_top2: 0.7088 - 1s/step
step 280/12272 - loss: 1.0850 - acc_top1: 0.3737 - acc_top2: 0.7098 - 1s/step
step 290/12272 - loss: 1.0759 - acc_top1: 0.3747 - acc_top2: 0.7100 - 1s/step
step 300/12272 - loss: 1.0352 - acc_top1: 0.3758 - acc_top2: 0.7108 - 1s/step
step 310/12272 - loss: 1.0224 - acc_top1: 0.3786 - acc_top2: 0.7127 - 1s/step
step 320/12272 - loss: 1.0919 - acc_top1: 0.3800 - acc_top2: 0.7137 - 1s/step
step 330/12272 - loss: 1.0884 - acc_top1: 0.3825 - acc_top2: 0.7145 - 1s/step
step 340/12272 - loss: 1.1380 - acc_top1: 0.3849 - acc_top2: 0.7157 - 1s/step
step 350/12272 - loss: 0.9523 - acc_top1: 0.3890 - acc_top2: 0.7176 - 1s/step
step 360/12272 - loss: 0.9963 - acc_top1: 0.3922 - acc_top2: 0.7191 - 1s/step
step 370/12272 - loss: 1.1187 - acc_top1: 0.3955 - acc_top2: 0.7205 - 1s/step
step 380/12272 - loss: 0.9634 - acc_top1: 0.3988 - acc_top2: 0.7229 - 1s/step
step 390/12272 - loss: 0.9944 - acc_top1: 0.4017 - acc_top2: 0.7254 - 1s/step
step 400/12272 - loss: 1.1071 - acc_top1: 0.4044 - acc_top2: 0.7272 - 1s/step
step 410/12272 - loss: 0.9307 - acc_top1: 0.4070 - acc_top2: 0.7293 - 1s/step
step 420/12272 - loss: 1.1307 - acc_top1: 0.4087 - acc_top2: 0.7315 - 1s/step
step 430/12272 - loss: 0.9936 - acc_top1: 0.4110 - acc_top2: 0.7334 - 1s/step
step 440/12272 - loss: 0.9791 - acc_top1: 0.4139 - acc_top2: 0.7357 - 1s/step
step 450/12272 - loss: 1.0112 - acc_top1: 0.4147 - acc_top2: 0.7372 - 1s/step
step 460/12272 - loss: 0.8554 - acc_top1: 0.4179 - acc_top2: 0.7395 - 1s/step
step 470/12272 - loss: 0.9411 - acc_top1: 0.4198 - acc_top2: 0.7406 - 1s/step
step 480/12272 - loss: 0.8481 - acc_top1: 0.4231 - acc_top2: 0.7424 - 1s/step
step 490/12272 - loss: 1.0338 - acc_top1: 0.4261 - acc_top2: 0.7441 - 1s/step
step 500/12272 - loss: 0.9651 - acc_top1: 0.4281 - acc_top2: 0.7459 - 1s/step
step 510/12272 - loss: 0.8091 - acc_top1: 0.4306 - acc_top2: 0.7479 - 1s/step
step 520/12272 - loss: 1.0528 - acc_top1: 0.4325 - acc_top2: 0.7489 - 1s/step
step 530/12272 - loss: 0.9898 - acc_top1: 0.4338 - acc_top2: 0.7500 - 1s/step
step 540/12272 - loss: 0.7900 - acc_top1: 0.4364 - acc_top2: 0.7519 - 1s/step
step 550/12272 - loss: 0.9055 - acc_top1: 0.4389 - acc_top2: 0.7534 - 1s/step
step 560/12272 - loss: 1.0092 - acc_top1: 0.4410 - acc_top2: 0.7549 - 1s/step
step 570/12272 - loss: 0.7068 - acc_top1: 0.4441 - acc_top2: 0.7570 - 1s/step
step 580/12272 - loss: 0.9695 - acc_top1: 0.4455 - acc_top2: 0.7581 - 1s/step
step 590/12272 - loss: 0.8640 - acc_top1: 0.4487 - acc_top2: 0.7600 - 1s/step
step 600/12272 - loss: 0.9068 - acc_top1: 0.4514 - acc_top2: 0.7618 - 1s/step
step 610/12272 - loss: 0.9023 - acc_top1: 0.4524 - acc_top2: 0.7627 - 1s/step
step 620/12272 - loss: 0.7377 - acc_top1: 0.4552 - acc_top2: 0.7640 - 1s/step
step 630/12272 - loss: 0.8900 - acc_top1: 0.4574 - acc_top2: 0.7659 - 1s/step
step 640/12272 - loss: 0.8902 - acc_top1: 0.4590 - acc_top2: 0.7669 - 1s/step
step 650/12272 - loss: 0.9069 - acc_top1: 0.4608 - acc_top2: 0.7686 - 1s/step
step 660/12272 - loss: 0.9630 - acc_top1: 0.4631 - acc_top2: 0.7699 - 1s/step
step 670/12272 - loss: 0.9005 - acc_top1: 0.4652 - acc_top2: 0.7712 - 1s/step
step 680/12272 - loss: 1.0725 - acc_top1: 0.4670 - acc_top2: 0.7725 - 1s/step
step 690/12272 - loss: 0.8322 - acc_top1: 0.4689 - acc_top2: 0.7739 - 1s/step
step 700/12272 - loss: 0.9874 - acc_top1: 0.4714 - acc_top2: 0.7753 - 1s/step
step 710/12272 - loss: 0.7915 - acc_top1: 0.4728 - acc_top2: 0.7765 - 1s/step
step 720/12272 - loss: 0.7174 - acc_top1: 0.4746 - acc_top2: 0.7777 - 1s/step
step 730/12272 - loss: 0.7635 - acc_top1: 0.4770 - acc_top2: 0.7793 - 1s/step
step 740/12272 - loss: 0.9180 - acc_top1: 0.4793 - acc_top2: 0.7804 - 1s/step
step 750/12272 - loss: 0.8424 - acc_top1: 0.4817 - acc_top2: 0.7815 - 1s/step
step 760/12272 - loss: 0.9357 - acc_top1: 0.4837 - acc_top2: 0.7829 - 1s/step
step 770/12272 - loss: 0.7643 - acc_top1: 0.4858 - acc_top2: 0.7839 - 1s/step
step 780/12272 - loss: 0.8910 - acc_top1: 0.4868 - acc_top2: 0.7849 - 1s/step
step 790/12272 - loss: 0.8781 - acc_top1: 0.4888 - acc_top2: 0.7862 - 1s/step
step 800/12272 - loss: 0.8005 - acc_top1: 0.4907 - acc_top2: 0.7877 - 1s/step
step 810/12272 - loss: 0.6740 - acc_top1: 0.4929 - acc_top2: 0.7889 - 1s/step
step 820/12272 - loss: 0.7026 - acc_top1: 0.4947 - acc_top2: 0.7898 - 1s/step
step 830/12272 - loss: 0.8666 - acc_top1: 0.4964 - acc_top2: 0.7908 - 1s/step
step 840/12272 - loss: 0.6296 - acc_top1: 0.4983 - acc_top2: 0.7920 - 1s/step
step 850/12272 - loss: 0.7907 - acc_top1: 0.4992 - acc_top2: 0.7930 - 1s/step
step 860/12272 - loss: 0.7292 - acc_top1: 0.5007 - acc_top2: 0.7935 - 1s/step
step 870/12272 - loss: 0.7498 - acc_top1: 0.5026 - acc_top2: 0.7944 - 1s/step
step 880/12272 - loss: 0.9928 - acc_top1: 0.5040 - acc_top2: 0.7953 - 1s/step
step 890/12272 - loss: 1.0025 - acc_top1: 0.5056 - acc_top2: 0.7962 - 1s/step
step 900/12272 - loss: 0.7810 - acc_top1: 0.5071 - acc_top2: 0.7969 - 1s/step
step 910/12272 - loss: 0.6114 - acc_top1: 0.5090 - acc_top2: 0.7978 - 1s/step
step 920/12272 - loss: 0.7780 - acc_top1: 0.5105 - acc_top2: 0.7988 - 1s/step
step 930/12272 - loss: 0.9457 - acc_top1: 0.5116 - acc_top2: 0.7995 - 1s/step
step 940/12272 - loss: 0.7907 - acc_top1: 0.5135 - acc_top2: 0.8006 - 1s/step
step 950/12272 - loss: 0.5520 - acc_top1: 0.5153 - acc_top2: 0.8013 - 1s/step
step 960/12272 - loss: 0.8251 - acc_top1: 0.5168 - acc_top2: 0.8022 - 1s/step
step 970/12272 - loss: 0.8482 - acc_top1: 0.5179 - acc_top2: 0.8031 - 1s/step
step 980/12272 - loss: 0.8010 - acc_top1: 0.5196 - acc_top2: 0.8038 - 1s/step
step 990/12272 - loss: 0.8326 - acc_top1: 0.5207 - acc_top2: 0.8047 - 1s/step
step 1000/12272 - loss: 0.6979 - acc_top1: 0.5222 - acc_top2: 0.8057 - 1s/step
step 1010/12272 - loss: 0.7506 - acc_top1: 0.5234 - acc_top2: 0.8065 - 1s/step
step 1020/12272 - loss: 0.8457 - acc_top1: 0.5248 - acc_top2: 0.8073 - 1s/step
step 1030/12272 - loss: 0.8698 - acc_top1: 0.5263 - acc_top2: 0.8082 - 1s/step
step 1040/12272 - loss: 0.7016 - acc_top1: 0.5279 - acc_top2: 0.8091 - 1s/step
step 1050/12272 - loss: 0.7766 - acc_top1: 0.5290 - acc_top2: 0.8099 - 1s/step
step 1060/12272 - loss: 0.7994 - acc_top1: 0.5300 - acc_top2: 0.8105 - 1s/step
step 1070/12272 - loss: 0.7053 - acc_top1: 0.5317 - acc_top2: 0.8115 - 1s/step
step 1080/12272 - loss: 0.9085 - acc_top1: 0.5330 - acc_top2: 0.8125 - 1s/step
step 1090/12272 - loss: 0.7556 - acc_top1: 0.5342 - acc_top2: 0.8134 - 1s/step
step 1100/12272 - loss: 0.9364 - acc_top1: 0.5355 - acc_top2: 0.8141 - 1s/step
step 1110/12272 - loss: 0.9403 - acc_top1: 0.5367 - acc_top2: 0.8148 - 1s/step
step 1120/12272 - loss: 0.8228 - acc_top1: 0.5375 - acc_top2: 0.8152 - 1s/step
step 1130/12272 - loss: 0.6802 - acc_top1: 0.5388 - acc_top2: 0.8160 - 1s/step
step 1140/12272 - loss: 0.8222 - acc_top1: 0.5397 - acc_top2: 0.8167 - 1s/step
step 1150/12272 - loss: 0.9321 - acc_top1: 0.5407 - acc_top2: 0.8172 - 1s/step
step 1160/12272 - loss: 0.7478 - acc_top1: 0.5417 - acc_top2: 0.8181 - 1s/step
step 1170/12272 - loss: 0.7976 - acc_top1: 0.5430 - acc_top2: 0.8188 - 1s/step
step 1180/12272 - loss: 0.7386 - acc_top1: 0.5441 - acc_top2: 0.8192 - 1s/step
step 1190/12272 - loss: 0.6448 - acc_top1: 0.5450 - acc_top2: 0.8200 - 1s/step
step 1200/12272 - loss: 0.7441 - acc_top1: 0.5463 - acc_top2: 0.8206 - 1s/step
step 1210/12272 - loss: 0.8171 - acc_top1: 0.5476 - acc_top2: 0.8213 - 1s/step
step 1220/12272 - loss: 0.7480 - acc_top1: 0.5487 - acc_top2: 0.8219 - 1s/step
step 1230/12272 - loss: 0.6363 - acc_top1: 0.5497 - acc_top2: 0.8225 - 1s/step
step 1240/12272 - loss: 0.6630 - acc_top1: 0.5507 - acc_top2: 0.8231 - 1s/step
step 1250/12272 - loss: 0.8668 - acc_top1: 0.5517 - acc_top2: 0.8237 - 1s/step
step 1260/12272 - loss: 0.6057 - acc_top1: 0.5527 - acc_top2: 0.8243 - 1s/step
step 1270/12272 - loss: 0.8432 - acc_top1: 0.5538 - acc_top2: 0.8248 - 1s/step
step 1280/12272 - loss: 0.8447 - acc_top1: 0.5546 - acc_top2: 0.8253 - 1s/step
step 1290/12272 - loss: 0.6928 - acc_top1: 0.5556 - acc_top2: 0.8261 - 1s/step
step 1300/12272 - loss: 0.7872 - acc_top1: 0.5567 - acc_top2: 0.8266 - 1s/step
step 1310/12272 - loss: 0.7968 - acc_top1: 0.5570 - acc_top2: 0.8269 - 1s/step
step 1320/12272 - loss: 0.8059 - acc_top1: 0.5580 - acc_top2: 0.8275 - 1s/step
step 1330/12272 - loss: 0.8603 - acc_top1: 0.5587 - acc_top2: 0.8278 - 1s/step
step 1340/12272 - loss: 0.7872 - acc_top1: 0.5599 - acc_top2: 0.8285 - 1s/step
step 1350/12272 - loss: 0.7037 - acc_top1: 0.5609 - acc_top2: 0.8290 - 1s/step
step 1360/12272 - loss: 0.8268 - acc_top1: 0.5618 - acc_top2: 0.8297 - 1s/step
step 1370/12272 - loss: 0.5962 - acc_top1: 0.5627 - acc_top2: 0.8303 - 1s/step
step 1380/12272 - loss: 0.7712 - acc_top1: 0.5638 - acc_top2: 0.8310 - 1s/step
step 1390/12272 - loss: 0.5770 - acc_top1: 0.5650 - acc_top2: 0.8315 - 1s/step
step 1400/12272 - loss: 0.7174 - acc_top1: 0.5656 - acc_top2: 0.8319 - 1s/step
step 1410/12272 - loss: 0.6224 - acc_top1: 0.5660 - acc_top2: 0.8323 - 1s/step
step 1420/12272 - loss: 0.6782 - acc_top1: 0.5671 - acc_top2: 0.8328 - 1s/step
step 1430/12272 - loss: 0.4087 - acc_top1: 0.5682 - acc_top2: 0.8335 - 1s/step
step 1440/12272 - loss: 0.7534 - acc_top1: 0.5692 - acc_top2: 0.8342 - 1s/step
step 1450/12272 - loss: 0.6446 - acc_top1: 0.5702 - acc_top2: 0.8345 - 1s/step
step 1460/12272 - loss: 0.6606 - acc_top1: 0.5712 - acc_top2: 0.8351 - 1s/step
step 1470/12272 - loss: 0.7308 - acc_top1: 0.5723 - acc_top2: 0.8357 - 1s/step
step 1480/12272 - loss: 0.9016 - acc_top1: 0.5727 - acc_top2: 0.8359 - 1s/step
step 1490/12272 - loss: 0.8445 - acc_top1: 0.5730 - acc_top2: 0.8362 - 1s/step
step 1500/12272 - loss: 0.8217 - acc_top1: 0.5737 - acc_top2: 0.8367 - 1s/step
step 1510/12272 - loss: 0.8413 - acc_top1: 0.5747 - acc_top2: 0.8370 - 1s/step
step 1520/12272 - loss: 0.4643 - acc_top1: 0.5757 - acc_top2: 0.8376 - 1s/step
step 1530/12272 - loss: 0.9351 - acc_top1: 0.5764 - acc_top2: 0.8381 - 1s/step
step 1540/12272 - loss: 0.7856 - acc_top1: 0.5773 - acc_top2: 0.8386 - 1s/step
step 1550/12272 - loss: 0.5921 - acc_top1: 0.5780 - acc_top2: 0.8390 - 1s/step
step 1560/12272 - loss: 0.4460 - acc_top1: 0.5788 - acc_top2: 0.8395 - 1s/step
step 1570/12272 - loss: 0.6814 - acc_top1: 0.5793 - acc_top2: 0.8401 - 1s/step
step 1580/12272 - loss: 0.4115 - acc_top1: 0.5805 - acc_top2: 0.8407 - 1s/step
step 1590/12272 - loss: 0.9326 - acc_top1: 0.5810 - acc_top2: 0.8410 - 1s/step
step 1600/12272 - loss: 0.6989 - acc_top1: 0.5818 - acc_top2: 0.8413 - 1s/step
step 1610/12272 - loss: 0.5238 - acc_top1: 0.5826 - acc_top2: 0.8418 - 1s/step
step 1620/12272 - loss: 0.5827 - acc_top1: 0.5832 - acc_top2: 0.8422 - 1s/step
step 1630/12272 - loss: 0.7703 - acc_top1: 0.5838 - acc_top2: 0.8425 - 1s/step
step 1640/12272 - loss: 0.7926 - acc_top1: 0.5844 - acc_top2: 0.8428 - 1s/step
step 1650/12272 - loss: 0.7143 - acc_top1: 0.5851 - acc_top2: 0.8434 - 1s/step
step 1660/12272 - loss: 0.6240 - acc_top1: 0.5858 - acc_top2: 0.8438 - 1s/step
step 1670/12272 - loss: 0.7869 - acc_top1: 0.5862 - acc_top2: 0.8440 - 1s/step
step 1680/12272 - loss: 0.6485 - acc_top1: 0.5868 - acc_top2: 0.8444 - 1s/step
step 1690/12272 - loss: 0.7539 - acc_top1: 0.5876 - acc_top2: 0.8450 - 1s/step
step 1700/12272 - loss: 0.6173 - acc_top1: 0.5882 - acc_top2: 0.8454 - 1s/step
step 1710/12272 - loss: 0.8056 - acc_top1: 0.5890 - acc_top2: 0.8458 - 1s/step
step 1720/12272 - loss: 0.7035 - acc_top1: 0.5898 - acc_top2: 0.8463 - 1s/step
step 1730/12272 - loss: 0.5892 - acc_top1: 0.5908 - acc_top2: 0.8468 - 1s/step
step 1740/12272 - loss: 0.7755 - acc_top1: 0.5915 - acc_top2: 0.8472 - 1s/step
step 1750/12272 - loss: 0.6911 - acc_top1: 0.5920 - acc_top2: 0.8474 - 1s/step
step 1760/12272 - loss: 0.6309 - acc_top1: 0.5926 - acc_top2: 0.8477 - 1s/step
step 1770/12272 - loss: 0.7506 - acc_top1: 0.5932 - acc_top2: 0.8480 - 1s/step
step 1780/12272 - loss: 0.8711 - acc_top1: 0.5939 - acc_top2: 0.8482 - 1s/step
step 1790/12272 - loss: 0.9146 - acc_top1: 0.5945 - acc_top2: 0.8484 - 1s/step
step 1800/12272 - loss: 0.6208 - acc_top1: 0.5952 - acc_top2: 0.8487 - 1s/step
step 1810/12272 - loss: 0.8506 - acc_top1: 0.5959 - acc_top2: 0.8490 - 1s/step
step 1820/12272 - loss: 0.8330 - acc_top1: 0.5965 - acc_top2: 0.8494 - 1s/step
step 1830/12272 - loss: 0.8315 - acc_top1: 0.5970 - acc_top2: 0.8497 - 1s/step
step 1840/12272 - loss: 0.6227 - acc_top1: 0.5977 - acc_top2: 0.8501 - 1s/step
step 1850/12272 - loss: 0.5972 - acc_top1: 0.5985 - acc_top2: 0.8506 - 1s/step
step 1860/12272 - loss: 0.6309 - acc_top1: 0.5992 - acc_top2: 0.8510 - 1s/step
step 1870/12272 - loss: 0.8707 - acc_top1: 0.5995 - acc_top2: 0.8512 - 1s/step
step 1880/12272 - loss: 0.6419 - acc_top1: 0.6004 - acc_top2: 0.8516 - 1s/step
step 1890/12272 - loss: 0.6015 - acc_top1: 0.6010 - acc_top2: 0.8521 - 1s/step
step 1900/12272 - loss: 0.6000 - acc_top1: 0.6015 - acc_top2: 0.8524 - 1s/step
step 1910/12272 - loss: 0.7010 - acc_top1: 0.6020 - acc_top2: 0.8527 - 1s/step
step 1920/12272 - loss: 0.8539 - acc_top1: 0.6026 - acc_top2: 0.8530 - 1s/step
step 1930/12272 - loss: 0.8381 - acc_top1: 0.6031 - acc_top2: 0.8533 - 1s/step
step 1940/12272 - loss: 0.5921 - acc_top1: 0.6039 - acc_top2: 0.8537 - 1s/step
step 1950/12272 - loss: 0.4974 - acc_top1: 0.6047 - acc_top2: 0.8541 - 1s/step
step 1960/12272 - loss: 0.8269 - acc_top1: 0.6052 - acc_top2: 0.8544 - 1s/step
step 1970/12272 - loss: 0.6157 - acc_top1: 0.6058 - acc_top2: 0.8548 - 1s/step
step 1980/12272 - loss: 1.0949 - acc_top1: 0.6064 - acc_top2: 0.8552 - 1s/step
step 1990/12272 - loss: 0.6442 - acc_top1: 0.6070 - acc_top2: 0.8555 - 1s/step
step 2000/12272 - loss: 0.8747 - acc_top1: 0.6073 - acc_top2: 0.8558 - 1s/step
step 2010/12272 - loss: 0.8101 - acc_top1: 0.6078 - acc_top2: 0.8560 - 1s/step
step 2020/12272 - loss: 0.8623 - acc_top1: 0.6082 - acc_top2: 0.8562 - 1s/step
step 2030/12272 - loss: 0.6664 - acc_top1: 0.6089 - acc_top2: 0.8567 - 1s/step
step 2040/12272 - loss: 0.7616 - acc_top1: 0.6092 - acc_top2: 0.8567 - 1s/step
step 2050/12272 - loss: 0.7282 - acc_top1: 0.6095 - acc_top2: 0.8570 - 1s/step
step 2060/12272 - loss: 0.6914 - acc_top1: 0.6099 - acc_top2: 0.8574 - 1s/step
step 2070/12272 - loss: 0.6129 - acc_top1: 0.6105 - acc_top2: 0.8577 - 1s/step
step 2080/12272 - loss: 0.5605 - acc_top1: 0.6111 - acc_top2: 0.8580 - 1s/step
step 2090/12272 - loss: 0.6432 - acc_top1: 0.6116 - acc_top2: 0.8582 - 1s/step
step 2100/12272 - loss: 0.6783 - acc_top1: 0.6121 - acc_top2: 0.8586 - 1s/step
step 2110/12272 - loss: 0.5949 - acc_top1: 0.6128 - acc_top2: 0.8589 - 1s/step
step 2120/12272 - loss: 0.7832 - acc_top1: 0.6134 - acc_top2: 0.8592 - 1s/step
step 2130/12272 - loss: 0.6633 - acc_top1: 0.6139 - acc_top2: 0.8594 - 1s/step
step 2140/12272 - loss: 0.8456 - acc_top1: 0.6143 - acc_top2: 0.8596 - 1s/step
step 2150/12272 - loss: 0.7133 - acc_top1: 0.6150 - acc_top2: 0.8599 - 1s/step
step 2160/12272 - loss: 0.4699 - acc_top1: 0.6155 - acc_top2: 0.8602 - 1s/step
step 2170/12272 - loss: 0.6013 - acc_top1: 0.6161 - acc_top2: 0.8605 - 1s/step
step 2180/12272 - loss: 0.5676 - acc_top1: 0.6165 - acc_top2: 0.8608 - 1s/step
step 2190/12272 - loss: 0.5850 - acc_top1: 0.6172 - acc_top2: 0.8611 - 1s/step
step 2200/12272 - loss: 0.6887 - acc_top1: 0.6177 - acc_top2: 0.8612 - 1s/step
step 2210/12272 - loss: 0.5706 - acc_top1: 0.6180 - acc_top2: 0.8614 - 1s/step
step 2220/12272 - loss: 0.8251 - acc_top1: 0.6184 - acc_top2: 0.8617 - 1s/step
step 2230/12272 - loss: 0.6532 - acc_top1: 0.6188 - acc_top2: 0.8620 - 1s/step
step 2240/12272 - loss: 0.5888 - acc_top1: 0.6194 - acc_top2: 0.8623 - 1s/step
step 2250/12272 - loss: 0.6360 - acc_top1: 0.6198 - acc_top2: 0.8625 - 1s/step
step 2260/12272 - loss: 1.0555 - acc_top1: 0.6202 - acc_top2: 0.8628 - 1s/step
step 2270/12272 - loss: 0.4848 - acc_top1: 0.6207 - acc_top2: 0.8629 - 1s/step
step 2280/12272 - loss: 0.7243 - acc_top1: 0.6212 - acc_top2: 0.8632 - 1s/step
step 2290/12272 - loss: 0.4358 - acc_top1: 0.6216 - acc_top2: 0.8635 - 1s/step
step 2300/12272 - loss: 0.5473 - acc_top1: 0.6221 - acc_top2: 0.8637 - 1s/step
step 2310/12272 - loss: 0.6440 - acc_top1: 0.6226 - acc_top2: 0.8640 - 1s/step
step 2320/12272 - loss: 0.5785 - acc_top1: 0.6233 - acc_top2: 0.8643 - 1s/step
step 2330/12272 - loss: 0.7199 - acc_top1: 0.6237 - acc_top2: 0.8646 - 1s/step
step 2340/12272 - loss: 0.5622 - acc_top1: 0.6241 - acc_top2: 0.8647 - 1s/step
step 2350/12272 - loss: 0.6742 - acc_top1: 0.6245 - acc_top2: 0.8650 - 1s/step
step 2360/12272 - loss: 0.8149 - acc_top1: 0.6249 - acc_top2: 0.8652 - 1s/step
step 2370/12272 - loss: 0.5900 - acc_top1: 0.6253 - acc_top2: 0.8654 - 1s/step
step 2380/12272 - loss: 0.8046 - acc_top1: 0.6256 - acc_top2: 0.8656 - 1s/step
step 2390/12272 - loss: 0.6097 - acc_top1: 0.6262 - acc_top2: 0.8659 - 1s/step
step 2400/12272 - loss: 0.5936 - acc_top1: 0.6266 - acc_top2: 0.8660 - 1s/step
step 2410/12272 - loss: 0.7245 - acc_top1: 0.6270 - acc_top2: 0.8662 - 1s/step
step 2420/12272 - loss: 0.6349 - acc_top1: 0.6274 - acc_top2: 0.8665 - 1s/step
step 2430/12272 - loss: 0.7009 - acc_top1: 0.6278 - acc_top2: 0.8668 - 1s/step
step 2440/12272 - loss: 0.3881 - acc_top1: 0.6282 - acc_top2: 0.8670 - 1s/step
step 2450/12272 - loss: 0.5226 - acc_top1: 0.6286 - acc_top2: 0.8673 - 1s/step
step 2460/12272 - loss: 0.5748 - acc_top1: 0.6292 - acc_top2: 0.8675 - 1s/step
step 2470/12272 - loss: 0.4798 - acc_top1: 0.6297 - acc_top2: 0.8678 - 1s/step
step 2480/12272 - loss: 0.5857 - acc_top1: 0.6303 - acc_top2: 0.8680 - 1s/step
step 2490/12272 - loss: 0.6729 - acc_top1: 0.6308 - acc_top2: 0.8683 - 1s/step
step 2500/12272 - loss: 0.6392 - acc_top1: 0.6312 - acc_top2: 0.8686 - 1s/step
step 2510/12272 - loss: 0.9607 - acc_top1: 0.6315 - acc_top2: 0.8687 - 1s/step
step 2520/12272 - loss: 0.6036 - acc_top1: 0.6319 - acc_top2: 0.8690 - 1s/step
step 2530/12272 - loss: 0.6505 - acc_top1: 0.6324 - acc_top2: 0.8693 - 1s/step
step 2540/12272 - loss: 0.4558 - acc_top1: 0.6329 - acc_top2: 0.8696 - 1s/step
step 2550/12272 - loss: 0.4215 - acc_top1: 0.6333 - acc_top2: 0.8699 - 1s/step
step 2560/12272 - loss: 0.6908 - acc_top1: 0.6338 - acc_top2: 0.8701 - 1s/step
step 2570/12272 - loss: 0.5833 - acc_top1: 0.6342 - acc_top2: 0.8703 - 1s/step
step 2580/12272 - loss: 0.8548 - acc_top1: 0.6346 - acc_top2: 0.8706 - 1s/step
step 2590/12272 - loss: 0.5770 - acc_top1: 0.6351 - acc_top2: 0.8708 - 1s/step
step 2600/12272 - loss: 0.4476 - acc_top1: 0.6355 - acc_top2: 0.8711 - 1s/step
step 2610/12272 - loss: 0.4145 - acc_top1: 0.6360 - acc_top2: 0.8714 - 1s/step
step 2620/12272 - loss: 0.6625 - acc_top1: 0.6365 - acc_top2: 0.8717 - 1s/step
step 2630/12272 - loss: 0.4808 - acc_top1: 0.6369 - acc_top2: 0.8719 - 1s/step
examples/bert_leveldb/run_classifier_single_gpu.sh
0 → 100755
浏览文件 @
1373e294
#!/bin/bash
BERT_BASE_PATH
=
"./data/pretrained_models/uncased_L-12_H-768_A-12/"
TASK_NAME
=
'MNLI'
DATA_PATH
=
"./data/glue_data/MNLI/"
CKPT_PATH
=
"./data/saved_model/mnli_models"
export
CUDA_VISIBLE_DEVICES
=
7
# start fine-tuning
python3.7 bert_classifier.py
\
--use_cuda
true
\
--do_train
true
\
--do_test
true
\
--batch_size
64
\
--init_pretraining_params
${
BERT_BASE_PATH
}
/dygraph_params/
\
--data_dir
${
DATA_PATH
}
\
--vocab_path
${
BERT_BASE_PATH
}
/vocab.txt
\
--checkpoints
${
CKPT_PATH
}
\
--save_steps
1000
\
--weight_decay
0.01
\
--warmup_proportion
0.1
\
--validation_steps
100
\
--epoch
3
\
--max_seq_len
128
\
--bert_config_path
${
BERT_BASE_PATH
}
/bert_config.json
\
--learning_rate
5e-5
\
--skip_steps
10
\
--shuffle
true
hapi/text/bert/dataloader.py
浏览文件 @
1373e294
...
@@ -19,6 +19,7 @@ import csv
...
@@ -19,6 +19,7 @@ import csv
import
glob
import
glob
import
tarfile
import
tarfile
import
itertools
import
itertools
import
leveldb
from
functools
import
partial
from
functools
import
partial
import
numpy
as
np
import
numpy
as
np
...
@@ -167,10 +168,14 @@ class SingleSentenceDataset(Dataset):
...
@@ -167,10 +168,14 @@ class SingleSentenceDataset(Dataset):
assert
isinstance
(
mode
,
assert
isinstance
(
mode
,
str
),
"mode of SingleSentenceDataset should be str"
str
),
"mode of SingleSentenceDataset should be str"
assert
mode
in
[
assert
mode
in
[
"all_in_memory"
,
"leveldb"
"all_in_memory"
,
"leveldb"
,
"streaming"
],
"mode of SingleSentenceDataset should be in [all_in_memory, leveldb], but get"
%
mode
],
"mode of SingleSentenceDataset should be in [all_in_memory, leveldb
, streaming
], but get"
%
mode
self
.
delimiter
=
None
self
.
mode
=
mode
self
.
examples
=
[]
self
.
examples
=
[]
self
.
_db
=
None
self
.
_line_processor
=
None
def
load_all_data_in_memory
(
self
,
def
load_all_data_in_memory
(
self
,
input_file
,
input_file
,
...
@@ -202,13 +207,87 @@ class SingleSentenceDataset(Dataset):
...
@@ -202,13 +207,87 @@ class SingleSentenceDataset(Dataset):
tokenizer
)
tokenizer
)
self
.
examples
.
append
(
input_feature
)
self
.
examples
.
append
(
input_feature
)
def
prepare_leveldb
(
self
,
input_file
,
leveldb_file
,
label_list
,
max_seq_length
,
tokenizer
,
line_processor
=
None
,
delimiter
=
"
\t
"
,
quotechar
=
None
):
def
default_line_processor
(
line_id
,
line
):
assert
len
(
line
)
==
2
text_a
=
line
[
0
]
label
=
line
[
1
]
return
BertInputExample
(
str
(
line_id
),
text_a
=
text_a
,
text_b
=
None
,
label
=
label
)
if
line_processor
is
None
:
line_processor
=
default_line_processor
if
not
os
.
path
.
exists
(
leveldb_file
):
print
(
"putting data %s into leveldb %s"
%
(
input_file
,
leveldb_file
))
_example_num
=
0
_db
=
leveldb
.
LevelDB
(
leveldb_file
,
create_if_missing
=
True
)
with
io
.
open
(
input_file
,
"r"
,
encoding
=
"utf8"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
delimiter
,
quotechar
=
quotechar
)
line_id
=
0
for
(
_line_id
,
line
)
in
enumerate
(
reader
):
if
line_processor
(
str
(
_line_id
),
line
)
is
None
:
continue
line_str
=
delimiter
.
join
(
line
)
_db
.
Put
(
str
(
line_id
).
encode
(
"utf8"
),
line_str
.
encode
(
"utf8"
))
line_id
+=
1
_example_num
+=
1
_db
.
Put
(
"_example_num_"
.
encode
(
"utf8"
),
str
(
_example_num
).
encode
(
"utf8"
))
else
:
_db
=
leveldb
.
LevelDB
(
leveldb_file
,
create_if_missing
=
False
)
self
.
label_list
=
label_list
self
.
max_seq_length
=
max_seq_length
self
.
tokenizer
=
tokenizer
self
.
delimiter
=
delimiter
self
.
_db
=
_db
self
.
_line_processor
=
line_processor
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
return
self
.
examples
[
idx
].
input_ids
,
self
.
examples
[
idx
].
pos_ids
,
self
.
examples
[
idx
].
segment_ids
,
self
.
examples
[
if
self
.
mode
==
"all_in_memory"
:
idx
].
label_id
return
self
.
examples
[
idx
].
input_ids
,
self
.
examples
[
idx
].
pos_ids
,
self
.
examples
[
idx
].
segment_ids
,
self
.
examples
[
idx
].
label_id
if
self
.
mode
==
"leveldb"
:
assert
self
.
_db
is
not
None
,
"you shold call prepare_leveldb before you run dataloader"
line_str
=
self
.
_db
.
Get
(
str
(
idx
).
encode
(
"utf8"
))
line_str
=
line_str
.
decode
(
"utf8"
)
line
=
line_str
.
split
(
self
.
delimiter
)
input_example
=
self
.
_line_processor
(
str
(
idx
+
1
),
line
)
input_example
=
convert_single_example
(
str
(
idx
+
1
),
input_example
,
self
.
label_list
,
self
.
max_seq_length
,
self
.
tokenizer
)
return
input_example
.
input_ids
,
input_example
.
pos_ids
,
input_example
.
segment_ids
,
input_example
.
label_id
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
examples
)
if
self
.
mode
==
"all_in_memory"
:
return
len
(
self
.
examples
)
if
self
.
mode
==
"leveldb"
:
assert
self
.
_db
is
not
None
,
"you shold call prepare_leveldb before you run dataloader"
exmaple_num
=
self
.
_db
.
Get
(
"_example_num_"
.
encode
(
"utf8"
))
exmaple_num
=
exmaple_num
.
decode
(
"utf8"
)
return
int
(
exmaple_num
)
class
SentencePairDataset
(
Dataset
):
class
SentencePairDataset
(
Dataset
):
...
@@ -299,6 +378,7 @@ class BertDataLoader(object):
...
@@ -299,6 +378,7 @@ class BertDataLoader(object):
shuffle
=
False
,
shuffle
=
False
,
drop_last
=
False
,
drop_last
=
False
,
mode
=
"all_in_memory"
,
mode
=
"all_in_memory"
,
leveldb_file
=
"./leveldb"
,
line_processor
=
None
,
line_processor
=
None
,
delimiter
=
"
\t
"
,
delimiter
=
"
\t
"
,
quotechar
=
None
,
quotechar
=
None
,
...
@@ -314,8 +394,10 @@ class BertDataLoader(object):
...
@@ -314,8 +394,10 @@ class BertDataLoader(object):
input_file
,
label_list
,
max_seq_length
,
tokenizer
,
input_file
,
label_list
,
max_seq_length
,
tokenizer
,
line_processor
,
delimiter
,
quotechar
)
line_processor
,
delimiter
,
quotechar
)
elif
mode
==
"leveldb"
:
elif
mode
==
"leveldb"
:
#TODO add leveldb reader
#prepare_leveldb(self, input_file, leveldb_file, label_list, max_seq_length, tokenizer, line_processor=None, delimiter="\t", quotechar=None):
pass
self
.
dataset
.
prepare_leveldb
(
input_file
,
leveldb_file
,
label_list
,
max_seq_length
,
tokenizer
,
line_processor
,
delimiter
,
quotechar
)
else
:
else
:
raise
ValueError
(
"mode should be in [all_in_memory, leveldb]"
)
raise
ValueError
(
"mode should be in [all_in_memory, leveldb]"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录