Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
2bbf29ae
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2bbf29ae
编写于
3月 27, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add bert-reader
上级
490d9e4e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
1786 addition
and
0 deletion
+1786
-0
demo/bert-cls/reader/__init__.py
demo/bert-cls/reader/__init__.py
+0
-0
demo/bert-cls/reader/cls.py
demo/bert-cls/reader/cls.py
+573
-0
demo/bert-cls/reader/pretraining.py
demo/bert-cls/reader/pretraining.py
+290
-0
demo/bert-cls/reader/squad.py
demo/bert-cls/reader/squad.py
+923
-0
未找到文件。
demo/bert-cls/reader/__init__.py
0 → 100644
浏览文件 @
2bbf29ae
demo/bert-cls/reader/cls.py
0 → 100644
浏览文件 @
2bbf29ae
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
types
import
csv
import
numpy
as
np
import
tokenization
from
batching
import
prepare_batch_data
class
DataProcessor
(
object
):
"""Base class for data converters for sequence classification data sets."""
def
__init__
(
self
,
data_dir
,
vocab_path
,
max_seq_len
,
do_lower_case
,
in_tokens
,
random_seed
=
None
):
self
.
data_dir
=
data_dir
self
.
max_seq_len
=
max_seq_len
self
.
tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
vocab_path
,
do_lower_case
=
do_lower_case
)
self
.
vocab
=
self
.
tokenizer
.
vocab
self
.
in_tokens
=
in_tokens
np
.
random
.
seed
(
random_seed
)
self
.
current_train_example
=
-
1
self
.
num_examples
=
{
'train'
:
-
1
,
'dev'
:
-
1
,
'test'
:
-
1
}
self
.
current_train_epoch
=
-
1
def
get_train_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the train set."""
raise
NotImplementedError
()
def
get_dev_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the dev set."""
raise
NotImplementedError
()
def
get_test_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for prediction."""
raise
NotImplementedError
()
def
get_labels
(
self
):
"""Gets the list of labels for this data set."""
raise
NotImplementedError
()
def
convert_example
(
self
,
index
,
example
,
labels
,
max_seq_len
,
tokenizer
):
"""Converts a single `InputExample` into a single `InputFeatures`."""
feature
=
convert_single_example
(
index
,
example
,
labels
,
max_seq_len
,
tokenizer
)
return
feature
def
generate_instance
(
self
,
feature
):
"""
generate instance with given feature
Args:
feature: InputFeatures(object). A single set of features of data.
"""
input_pos
=
list
(
range
(
len
(
feature
.
input_ids
)))
return
[
feature
.
input_ids
,
feature
.
segment_ids
,
input_pos
,
feature
.
label_id
]
def
generate_batch_data
(
self
,
batch_data
,
total_token_num
,
voc_size
=-
1
,
mask_id
=-
1
,
return_input_mask
=
True
,
return_max_len
=
False
,
return_num_token
=
False
):
return
prepare_batch_data
(
batch_data
,
total_token_num
,
voc_size
=-
1
,
pad_id
=
self
.
vocab
[
"[PAD]"
],
cls_id
=
self
.
vocab
[
"[CLS]"
],
sep_id
=
self
.
vocab
[
"[SEP]"
],
mask_id
=-
1
,
return_input_mask
=
True
,
return_max_len
=
False
,
return_num_token
=
False
)
@
classmethod
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
lines
=
[]
for
line
in
reader
:
lines
.
append
(
line
)
return
lines
def
get_num_examples
(
self
,
phase
):
"""Get number of examples for train, dev or test."""
if
phase
not
in
[
'train'
,
'dev'
,
'test'
]:
raise
ValueError
(
"Unknown phase, which should be in ['train', 'dev', 'test']."
)
return
self
.
num_examples
[
phase
]
def
get_train_progress
(
self
):
"""Gets progress for training phase."""
return
self
.
current_train_example
,
self
.
current_train_epoch
def
data_generator
(
self
,
batch_size
,
phase
=
'train'
,
epoch
=
1
,
shuffle
=
True
):
"""
Generate data for train, dev or test.
Args:
batch_size: int. The batch size of generated data.
phase: string. The phase for which to generate data.
epoch: int. Total epoches to generate data.
shuffle: bool. Whether to shuffle examples.
"""
if
phase
==
'train'
:
examples
=
self
.
get_train_examples
(
self
.
data_dir
)
self
.
num_examples
[
'train'
]
=
len
(
examples
)
elif
phase
==
'dev'
:
examples
=
self
.
get_dev_examples
(
self
.
data_dir
)
self
.
num_examples
[
'dev'
]
=
len
(
examples
)
elif
phase
==
'test'
:
examples
=
self
.
get_test_examples
(
self
.
data_dir
)
self
.
num_examples
[
'test'
]
=
len
(
examples
)
else
:
raise
ValueError
(
"Unknown phase, which should be in ['train', 'dev', 'test']."
)
def
instance_reader
():
for
epoch_index
in
range
(
epoch
):
if
shuffle
:
np
.
random
.
shuffle
(
examples
)
if
phase
==
'train'
:
self
.
current_train_epoch
=
epoch_index
for
(
index
,
example
)
in
enumerate
(
examples
):
if
phase
==
'train'
:
self
.
current_train_example
=
index
+
1
feature
=
self
.
convert_example
(
index
,
example
,
self
.
get_labels
(),
self
.
max_seq_len
,
self
.
tokenizer
)
instance
=
self
.
generate_instance
(
feature
)
yield
instance
def
batch_reader
(
reader
,
batch_size
,
in_tokens
):
batch
,
total_token_num
,
max_len
=
[],
0
,
0
for
instance
in
reader
():
token_ids
,
sent_ids
,
pos_ids
,
label
=
instance
[:
4
]
max_len
=
max
(
max_len
,
len
(
token_ids
))
if
in_tokens
:
to_append
=
(
len
(
batch
)
+
1
)
*
max_len
<=
batch_size
else
:
to_append
=
len
(
batch
)
<
batch_size
if
to_append
:
batch
.
append
(
instance
)
total_token_num
+=
len
(
token_ids
)
else
:
yield
batch
,
total_token_num
batch
,
total_token_num
,
max_len
=
[
instance
],
len
(
token_ids
),
len
(
token_ids
)
if
len
(
batch
)
>
0
:
yield
batch
,
total_token_num
def
wrapper
():
for
batch_data
,
total_token_num
in
batch_reader
(
instance_reader
,
batch_size
,
self
.
in_tokens
):
batch_data
=
self
.
generate_batch_data
(
batch_data
,
total_token_num
,
voc_size
=-
1
,
mask_id
=-
1
,
return_input_mask
=
True
,
return_max_len
=
False
,
return_num_token
=
False
)
yield
batch_data
return
wrapper
class
InputExample
(
object
):
"""A single training/test example for simple sequence classification."""
def
__init__
(
self
,
guid
,
text_a
,
text_b
=
None
,
label
=
None
):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self
.
guid
=
guid
self
.
text_a
=
text_a
self
.
text_b
=
text_b
self
.
label
=
label
def
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_length
):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while
True
:
total_length
=
len
(
tokens_a
)
+
len
(
tokens_b
)
if
total_length
<=
max_length
:
break
if
len
(
tokens_a
)
>
len
(
tokens_b
):
tokens_a
.
pop
()
else
:
tokens_b
.
pop
()
class
InputFeatures
(
object
):
"""A single set of features of data."""
def
__init__
(
self
,
input_ids
,
input_mask
,
segment_ids
,
label_id
):
self
.
input_ids
=
input_ids
self
.
input_mask
=
input_mask
self
.
segment_ids
=
segment_ids
self
.
label_id
=
label_id
class
XnliProcessor
(
DataProcessor
):
"""Processor for the XNLI data set."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
self
.
language
=
"zh"
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"multinli"
,
"multinli.train.%s.tsv"
%
self
.
language
))
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"train-%d"
%
(
i
)
text_a
=
tokenization
.
convert_to_unicode
(
line
[
0
])
text_b
=
tokenization
.
convert_to_unicode
(
line
[
1
])
label
=
tokenization
.
convert_to_unicode
(
line
[
2
])
if
label
==
tokenization
.
convert_to_unicode
(
"contradictory"
):
label
=
tokenization
.
convert_to_unicode
(
"contradiction"
)
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
self
.
language
=
"zh"
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.dev.tsv"
))
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"dev-%d"
%
(
i
)
language
=
tokenization
.
convert_to_unicode
(
line
[
0
])
if
language
!=
tokenization
.
convert_to_unicode
(
self
.
language
):
continue
text_a
=
tokenization
.
convert_to_unicode
(
line
[
6
])
text_b
=
tokenization
.
convert_to_unicode
(
line
[
7
])
label
=
tokenization
.
convert_to_unicode
(
line
[
1
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
self
.
language
=
"zh"
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.test.tsv"
))
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"test-%d"
%
(
i
)
language
=
tokenization
.
convert_to_unicode
(
line
[
0
])
if
language
!=
tokenization
.
convert_to_unicode
(
self
.
language
):
continue
text_a
=
tokenization
.
convert_to_unicode
(
line
[
6
])
text_b
=
tokenization
.
convert_to_unicode
(
line
[
7
])
label
=
tokenization
.
convert_to_unicode
(
line
[
1
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_labels
(
self
):
"""See base class."""
return
[
"contradiction"
,
"entailment"
,
"neutral"
]
class
MnliProcessor
(
DataProcessor
):
"""Processor for the MultiNLI data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev_matched.tsv"
)),
"dev_matched"
)
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test_matched.tsv"
)),
"test"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"contradiction"
,
"entailment"
,
"neutral"
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
tokenization
.
convert_to_unicode
(
line
[
0
]))
text_a
=
tokenization
.
convert_to_unicode
(
line
[
8
])
text_b
=
tokenization
.
convert_to_unicode
(
line
[
9
])
if
set_type
==
"test"
:
label
=
"contradiction"
else
:
label
=
tokenization
.
convert_to_unicode
(
line
[
-
1
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
class
MrpcProcessor
(
DataProcessor
):
"""Processor for the MRPC data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
text_a
=
tokenization
.
convert_to_unicode
(
line
[
3
])
text_b
=
tokenization
.
convert_to_unicode
(
line
[
4
])
if
set_type
==
"test"
:
label
=
"0"
else
:
label
=
tokenization
.
convert_to_unicode
(
line
[
0
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
class
ColaProcessor
(
DataProcessor
):
"""Processor for the CoLA data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
# Only the test set has a header
if
set_type
==
"test"
and
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
if
set_type
==
"test"
:
text_a
=
tokenization
.
convert_to_unicode
(
line
[
1
])
label
=
"0"
else
:
text_a
=
tokenization
.
convert_to_unicode
(
line
[
3
])
label
=
tokenization
.
convert_to_unicode
(
line
[
1
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
None
,
label
=
label
))
return
examples
class
ChnsenticorpProcessor
(
DataProcessor
):
"""Processor for the Chnsenticorp data set."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
guid
=
"%s-%s"
%
(
set_type
,
i
)
text_a
=
tokenization
.
convert_to_unicode
(
line
[
1
])
label
=
tokenization
.
convert_to_unicode
(
line
[
0
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
None
,
label
=
label
))
return
examples
def
convert_single_example_to_unicode
(
guid
,
single_example
):
text_a
=
tokenization
.
convert_to_unicode
(
single_example
[
0
])
text_b
=
tokenization
.
convert_to_unicode
(
single_example
[
1
])
label
=
tokenization
.
convert_to_unicode
(
single_example
[
2
])
return
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
)
def
convert_single_example
(
ex_index
,
example
,
label_list
,
max_seq_length
,
tokenizer
):
"""Converts a single `InputExample` into a single `InputFeatures`."""
label_map
=
{}
for
(
i
,
label
)
in
enumerate
(
label_list
):
label_map
[
label
]
=
i
tokens_a
=
tokenizer
.
tokenize
(
example
.
text_a
)
tokens_b
=
None
if
example
.
text_b
:
tokens_b
=
tokenizer
.
tokenize
(
example
.
text_b
)
if
tokens_b
:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_seq_length
-
3
)
else
:
# Account for [CLS] and [SEP] with "- 2"
if
len
(
tokens_a
)
>
max_seq_length
-
2
:
tokens_a
=
tokens_a
[
0
:(
max_seq_length
-
2
)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens
=
[]
segment_ids
=
[]
tokens
.
append
(
"[CLS]"
)
segment_ids
.
append
(
0
)
for
token
in
tokens_a
:
tokens
.
append
(
token
)
segment_ids
.
append
(
0
)
tokens
.
append
(
"[SEP]"
)
segment_ids
.
append
(
0
)
if
tokens_b
:
for
token
in
tokens_b
:
tokens
.
append
(
token
)
segment_ids
.
append
(
1
)
tokens
.
append
(
"[SEP]"
)
segment_ids
.
append
(
1
)
input_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask
=
[
1
]
*
len
(
input_ids
)
label_id
=
label_map
[
example
.
label
]
feature
=
InputFeatures
(
input_ids
=
input_ids
,
input_mask
=
input_mask
,
segment_ids
=
segment_ids
,
label_id
=
label_id
)
return
feature
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
):
"""Convert a set of `InputExample`s to a list of `InputFeatures`."""
features
=
[]
for
(
ex_index
,
example
)
in
enumerate
(
examples
):
if
ex_index
%
10000
==
0
:
print
(
"Writing example %d of %d"
%
(
ex_index
,
len
(
examples
)))
feature
=
convert_single_example
(
ex_index
,
example
,
label_list
,
max_seq_length
,
tokenizer
)
features
.
append
(
feature
)
return
features
if
__name__
==
'__main__'
:
pass
demo/bert-cls/reader/pretraining.py
0 → 100644
浏览文件 @
2bbf29ae
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
division
import
os
import
numpy
as
np
import
types
import
gzip
import
logging
import
re
import
six
import
collections
import
tokenization
import
paddle
import
paddle.fluid
as
fluid
from
batching
import
prepare_batch_data
class
DataReader
(
object
):
def
__init__
(
self
,
data_dir
,
vocab_path
,
batch_size
=
4096
,
in_tokens
=
True
,
max_seq_len
=
512
,
shuffle_files
=
True
,
epoch
=
100
,
voc_size
=
0
,
is_test
=
False
,
generate_neg_sample
=
False
):
self
.
vocab
=
self
.
load_vocab
(
vocab_path
)
self
.
data_dir
=
data_dir
self
.
batch_size
=
batch_size
self
.
in_tokens
=
in_tokens
self
.
shuffle_files
=
shuffle_files
self
.
epoch
=
epoch
self
.
current_epoch
=
0
self
.
current_file_index
=
0
self
.
total_file
=
0
self
.
current_file
=
None
self
.
voc_size
=
voc_size
self
.
max_seq_len
=
max_seq_len
self
.
pad_id
=
self
.
vocab
[
"[PAD]"
]
self
.
cls_id
=
self
.
vocab
[
"[CLS]"
]
self
.
sep_id
=
self
.
vocab
[
"[SEP]"
]
self
.
mask_id
=
self
.
vocab
[
"[MASK]"
]
self
.
is_test
=
is_test
self
.
generate_neg_sample
=
generate_neg_sample
if
self
.
in_tokens
:
assert
self
.
batch_size
>=
self
.
max_seq_len
,
"The number of "
\
"tokens in batch should not be smaller than max seq length."
if
self
.
is_test
:
self
.
epoch
=
1
self
.
shuffle_files
=
False
def
get_progress
(
self
):
"""return current progress of traning data
"""
return
self
.
current_epoch
,
self
.
current_file_index
,
self
.
total_file
,
self
.
current_file
def
parse_line
(
self
,
line
,
max_seq_len
=
512
):
""" parse one line to token_ids, sentence_ids, pos_ids, label
"""
line
=
line
.
strip
().
split
(
";"
)
assert
len
(
line
)
==
4
,
"One sample must have 4 fields!"
(
token_ids
,
sent_ids
,
pos_ids
,
label
)
=
line
token_ids
=
[
int
(
token
)
for
token
in
token_ids
.
split
(
" "
)]
sent_ids
=
[
int
(
token
)
for
token
in
sent_ids
.
split
(
" "
)]
pos_ids
=
[
int
(
token
)
for
token
in
pos_ids
.
split
(
" "
)]
assert
len
(
token_ids
)
==
len
(
sent_ids
)
==
len
(
pos_ids
),
"[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids)"
label
=
int
(
label
)
if
len
(
token_ids
)
>
max_seq_len
:
return
None
return
[
token_ids
,
sent_ids
,
pos_ids
,
label
]
def
read_file
(
self
,
file
):
assert
file
.
endswith
(
'.gz'
),
"[ERROR] %s is not a gzip file"
%
file
file_path
=
self
.
data_dir
+
"/"
+
file
with
gzip
.
open
(
file_path
,
"rb"
)
as
f
:
for
line
in
f
:
parsed_line
=
self
.
parse_line
(
line
,
max_seq_len
=
self
.
max_seq_len
)
if
parsed_line
is
None
:
continue
yield
parsed_line
def
convert_to_unicode
(
self
,
text
):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
elif
isinstance
(
text
,
unicode
):
return
text
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
def
load_vocab
(
self
,
vocab_file
):
"""Loads a vocabulary file into a dictionary."""
vocab
=
collections
.
OrderedDict
()
fin
=
open
(
vocab_file
)
for
num
,
line
in
enumerate
(
fin
):
items
=
self
.
convert_to_unicode
(
line
.
strip
()).
split
(
"
\t
"
)
if
len
(
items
)
>
2
:
break
token
=
items
[
0
]
index
=
items
[
1
]
if
len
(
items
)
==
2
else
num
token
=
token
.
strip
()
vocab
[
token
]
=
int
(
index
)
return
vocab
def
random_pair_neg_samples
(
self
,
pos_samples
):
""" randomly generate negtive samples using pos_samples
Args:
pos_samples: list of positive samples
Returns:
neg_samples: list of negtive samples
"""
np
.
random
.
shuffle
(
pos_samples
)
num_sample
=
len
(
pos_samples
)
neg_samples
=
[]
miss_num
=
0
for
i
in
range
(
num_sample
):
pair_index
=
(
i
+
1
)
%
num_sample
origin_src_ids
=
pos_samples
[
i
][
0
]
origin_sep_index
=
origin_src_ids
.
index
(
2
)
pair_src_ids
=
pos_samples
[
pair_index
][
0
]
pair_sep_index
=
pair_src_ids
.
index
(
2
)
src_ids
=
origin_src_ids
[:
origin_sep_index
+
1
]
+
pair_src_ids
[
pair_sep_index
+
1
:]
if
len
(
src_ids
)
>=
self
.
max_seq_len
:
miss_num
+=
1
continue
sent_ids
=
[
0
]
*
len
(
origin_src_ids
[:
origin_sep_index
+
1
])
+
[
1
]
*
len
(
pair_src_ids
[
pair_sep_index
+
1
:])
pos_ids
=
list
(
range
(
len
(
src_ids
)))
neg_sample
=
[
src_ids
,
sent_ids
,
pos_ids
,
0
]
assert
len
(
src_ids
)
==
len
(
sent_ids
)
==
len
(
pos_ids
),
"[ERROR]len(src_id) == lne(sent_id) == len(pos_id) must be True"
neg_samples
.
append
(
neg_sample
)
return
neg_samples
,
miss_num
def
mixin_negtive_samples
(
self
,
pos_sample_generator
,
buffer
=
1000
):
""" 1. generate negtive samples by randomly group sentence_1 and sentence_2 of positive samples
2. combine negtive samples and positive samples
Args:
pos_sample_generator: a generator producing a parsed positive sample, which is a list: [token_ids, sent_ids, pos_ids, 1]
Returns:
sample: one sample from shuffled positive samples and negtive samples
"""
pos_samples
=
[]
num_total_miss
=
0
pos_sample_num
=
0
try
:
while
True
:
while
len
(
pos_samples
)
<
buffer
:
pos_sample
=
next
(
pos_sample_generator
)
label
=
pos_sample
[
3
]
assert
label
==
1
,
"positive sample's label must be 1"
pos_samples
.
append
(
pos_sample
)
pos_sample_num
+=
1
neg_samples
,
miss_num
=
self
.
random_pair_neg_samples
(
pos_samples
)
num_total_miss
+=
miss_num
samples
=
pos_samples
+
neg_samples
pos_samples
=
[]
np
.
random
.
shuffle
(
samples
)
for
sample
in
samples
:
yield
sample
except
StopIteration
:
print
(
"stopiteration: reach end of file"
)
if
len
(
pos_samples
)
==
1
:
yield
pos_samples
[
0
]
elif
len
(
pos_samples
)
==
0
:
yield
None
else
:
neg_samples
,
miss_num
=
self
.
random_pair_neg_samples
(
pos_samples
)
num_total_miss
+=
miss_num
samples
=
pos_samples
+
neg_samples
pos_samples
=
[]
np
.
random
.
shuffle
(
samples
)
for
sample
in
samples
:
yield
sample
print
(
"miss_num:%d
\t
ideal_total_sample_num:%d
\t
miss_rate:%f"
%
(
num_total_miss
,
pos_sample_num
*
2
,
num_total_miss
/
(
pos_sample_num
*
2
)))
def
data_generator
(
self
):
"""
data_generator
"""
files
=
os
.
listdir
(
self
.
data_dir
)
self
.
total_file
=
len
(
files
)
assert
self
.
total_file
>
0
,
"[Error] data_dir is empty"
def
wrapper
():
def
reader
():
for
epoch
in
range
(
self
.
epoch
):
self
.
current_epoch
=
epoch
+
1
if
self
.
shuffle_files
:
np
.
random
.
shuffle
(
files
)
for
index
,
file
in
enumerate
(
files
):
self
.
current_file_index
=
index
+
1
self
.
current_file
=
file
sample_generator
=
self
.
read_file
(
file
)
if
not
self
.
is_test
and
self
.
generate_neg_sample
:
sample_generator
=
self
.
mixin_negtive_samples
(
sample_generator
)
for
sample
in
sample_generator
:
if
sample
is
None
:
continue
yield
sample
def
batch_reader
(
reader
,
batch_size
,
in_tokens
):
batch
,
total_token_num
,
max_len
=
[],
0
,
0
for
parsed_line
in
reader
():
token_ids
,
sent_ids
,
pos_ids
,
label
=
parsed_line
max_len
=
max
(
max_len
,
len
(
token_ids
))
if
in_tokens
:
to_append
=
(
len
(
batch
)
+
1
)
*
max_len
<=
batch_size
else
:
to_append
=
len
(
batch
)
<
batch_size
if
to_append
:
batch
.
append
(
parsed_line
)
total_token_num
+=
len
(
token_ids
)
else
:
yield
batch
,
total_token_num
batch
,
total_token_num
,
max_len
=
[
parsed_line
],
len
(
token_ids
),
len
(
token_ids
)
if
len
(
batch
)
>
0
:
yield
batch
,
total_token_num
for
batch_data
,
total_token_num
in
batch_reader
(
reader
,
self
.
batch_size
,
self
.
in_tokens
):
yield
prepare_batch_data
(
batch_data
,
total_token_num
,
voc_size
=
self
.
voc_size
,
pad_id
=
self
.
pad_id
,
cls_id
=
self
.
cls_id
,
sep_id
=
self
.
sep_id
,
mask_id
=
self
.
mask_id
,
return_input_mask
=
True
,
return_max_len
=
False
,
return_num_token
=
False
)
return
wrapper
if
__name__
==
"__main__"
:
pass
demo/bert-cls/reader/squad.py
0 → 100644
浏览文件 @
2bbf29ae
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on SQuAD 1.1 and SQuAD 2.0."""
import
six
import
math
import
json
import
random
import
collections
import
tokenization
from
batching
import
prepare_batch_data
class
SquadExample
(
object
):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def
__init__
(
self
,
qas_id
,
question_text
,
doc_tokens
,
orig_answer_text
=
None
,
start_position
=
None
,
end_position
=
None
,
is_impossible
=
False
):
self
.
qas_id
=
qas_id
self
.
question_text
=
question_text
self
.
doc_tokens
=
doc_tokens
self
.
orig_answer_text
=
orig_answer_text
self
.
start_position
=
start_position
self
.
end_position
=
end_position
self
.
is_impossible
=
is_impossible
def
__str__
(
self
):
return
self
.
__repr__
()
def
__repr__
(
self
):
s
=
""
s
+=
"qas_id: %s"
%
(
tokenization
.
printable_text
(
self
.
qas_id
))
s
+=
", question_text: %s"
%
(
tokenization
.
printable_text
(
self
.
question_text
))
s
+=
", doc_tokens: [%s]"
%
(
" "
.
join
(
self
.
doc_tokens
))
if
self
.
start_position
:
s
+=
", start_position: %d"
%
(
self
.
start_position
)
if
self
.
start_position
:
s
+=
", end_position: %d"
%
(
self
.
end_position
)
if
self
.
start_position
:
s
+=
", is_impossible: %r"
%
(
self
.
is_impossible
)
return
s
class
InputFeatures
(
object
):
"""A single set of features of data."""
def
__init__
(
self
,
unique_id
,
example_index
,
doc_span_index
,
tokens
,
token_to_orig_map
,
token_is_max_context
,
input_ids
,
input_mask
,
segment_ids
,
start_position
=
None
,
end_position
=
None
,
is_impossible
=
None
):
self
.
unique_id
=
unique_id
self
.
example_index
=
example_index
self
.
doc_span_index
=
doc_span_index
self
.
tokens
=
tokens
self
.
token_to_orig_map
=
token_to_orig_map
self
.
token_is_max_context
=
token_is_max_context
self
.
input_ids
=
input_ids
self
.
input_mask
=
input_mask
self
.
segment_ids
=
segment_ids
self
.
start_position
=
start_position
self
.
end_position
=
end_position
self
.
is_impossible
=
is_impossible
def
read_squad_examples
(
input_file
,
is_training
,
version_2_with_negative
=
False
):
"""Read a SQuAD json file into a list of SquadExample."""
with
open
(
input_file
,
"r"
)
as
reader
:
input_data
=
json
.
load
(
reader
)[
"data"
]
def
is_whitespace
(
c
):
if
c
==
" "
or
c
==
"
\t
"
or
c
==
"
\r
"
or
c
==
"
\n
"
or
ord
(
c
)
==
0x202F
:
return
True
return
False
examples
=
[]
for
entry
in
input_data
:
for
paragraph
in
entry
[
"paragraphs"
]:
paragraph_text
=
paragraph
[
"context"
]
doc_tokens
=
[]
char_to_word_offset
=
[]
prev_is_whitespace
=
True
for
c
in
paragraph_text
:
if
is_whitespace
(
c
):
prev_is_whitespace
=
True
else
:
if
prev_is_whitespace
:
doc_tokens
.
append
(
c
)
else
:
doc_tokens
[
-
1
]
+=
c
prev_is_whitespace
=
False
char_to_word_offset
.
append
(
len
(
doc_tokens
)
-
1
)
for
qa
in
paragraph
[
"qas"
]:
qas_id
=
qa
[
"id"
]
question_text
=
qa
[
"question"
]
start_position
=
None
end_position
=
None
orig_answer_text
=
None
is_impossible
=
False
if
is_training
:
if
version_2_with_negative
:
is_impossible
=
qa
[
"is_impossible"
]
if
(
len
(
qa
[
"answers"
])
!=
1
)
and
(
not
is_impossible
):
raise
ValueError
(
"For training, each question should have exactly 1 answer."
)
if
not
is_impossible
:
answer
=
qa
[
"answers"
][
0
]
orig_answer_text
=
answer
[
"text"
]
answer_offset
=
answer
[
"answer_start"
]
answer_length
=
len
(
orig_answer_text
)
start_position
=
char_to_word_offset
[
answer_offset
]
end_position
=
char_to_word_offset
[
answer_offset
+
answer_length
-
1
]
# Only add answers where the text can be exactly recovered from the
# document. If this CAN'T happen it's likely due to weird Unicode
# stuff so we will just skip the example.
#
# Note that this means for training mode, every example is NOT
# guaranteed to be preserved.
actual_text
=
" "
.
join
(
doc_tokens
[
start_position
:(
end_position
+
1
)])
cleaned_answer_text
=
" "
.
join
(
tokenization
.
whitespace_tokenize
(
orig_answer_text
))
if
actual_text
.
find
(
cleaned_answer_text
)
==
-
1
:
print
(
"Could not find answer: '%s' vs. '%s'"
,
actual_text
,
cleaned_answer_text
)
continue
else
:
start_position
=
-
1
end_position
=
-
1
orig_answer_text
=
""
example
=
SquadExample
(
qas_id
=
qas_id
,
question_text
=
question_text
,
doc_tokens
=
doc_tokens
,
orig_answer_text
=
orig_answer_text
,
start_position
=
start_position
,
end_position
=
end_position
,
is_impossible
=
is_impossible
)
examples
.
append
(
example
)
return
examples
def
convert_examples_to_features
(
examples
,
tokenizer
,
max_seq_length
,
doc_stride
,
max_query_length
,
is_training
,
#output_fn
):
"""Loads a data file into a list of `InputBatch`s."""
unique_id
=
1000000000
for
(
example_index
,
example
)
in
enumerate
(
examples
):
query_tokens
=
tokenizer
.
tokenize
(
example
.
question_text
)
if
len
(
query_tokens
)
>
max_query_length
:
query_tokens
=
query_tokens
[
0
:
max_query_length
]
tok_to_orig_index
=
[]
orig_to_tok_index
=
[]
all_doc_tokens
=
[]
for
(
i
,
token
)
in
enumerate
(
example
.
doc_tokens
):
orig_to_tok_index
.
append
(
len
(
all_doc_tokens
))
sub_tokens
=
tokenizer
.
tokenize
(
token
)
for
sub_token
in
sub_tokens
:
tok_to_orig_index
.
append
(
i
)
all_doc_tokens
.
append
(
sub_token
)
tok_start_position
=
None
tok_end_position
=
None
if
is_training
and
example
.
is_impossible
:
tok_start_position
=
-
1
tok_end_position
=
-
1
if
is_training
and
not
example
.
is_impossible
:
tok_start_position
=
orig_to_tok_index
[
example
.
start_position
]
if
example
.
end_position
<
len
(
example
.
doc_tokens
)
-
1
:
tok_end_position
=
orig_to_tok_index
[
example
.
end_position
+
1
]
-
1
else
:
tok_end_position
=
len
(
all_doc_tokens
)
-
1
(
tok_start_position
,
tok_end_position
)
=
_improve_answer_span
(
all_doc_tokens
,
tok_start_position
,
tok_end_position
,
tokenizer
,
example
.
orig_answer_text
)
# The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc
=
max_seq_length
-
len
(
query_tokens
)
-
3
# We can have documents that are longer than the maximum sequence length.
# To deal with this we do a sliding window approach, where we take chunks
# of the up to our max length with a stride of `doc_stride`.
_DocSpan
=
collections
.
namedtuple
(
# pylint: disable=invalid-name
"DocSpan"
,
[
"start"
,
"length"
])
doc_spans
=
[]
start_offset
=
0
while
start_offset
<
len
(
all_doc_tokens
):
length
=
len
(
all_doc_tokens
)
-
start_offset
if
length
>
max_tokens_for_doc
:
length
=
max_tokens_for_doc
doc_spans
.
append
(
_DocSpan
(
start
=
start_offset
,
length
=
length
))
if
start_offset
+
length
==
len
(
all_doc_tokens
):
break
start_offset
+=
min
(
length
,
doc_stride
)
for
(
doc_span_index
,
doc_span
)
in
enumerate
(
doc_spans
):
tokens
=
[]
token_to_orig_map
=
{}
token_is_max_context
=
{}
segment_ids
=
[]
tokens
.
append
(
"[CLS]"
)
segment_ids
.
append
(
0
)
for
token
in
query_tokens
:
tokens
.
append
(
token
)
segment_ids
.
append
(
0
)
tokens
.
append
(
"[SEP]"
)
segment_ids
.
append
(
0
)
for
i
in
range
(
doc_span
.
length
):
split_token_index
=
doc_span
.
start
+
i
token_to_orig_map
[
len
(
tokens
)]
=
tok_to_orig_index
[
split_token_index
]
is_max_context
=
_check_is_max_context
(
doc_spans
,
doc_span_index
,
split_token_index
)
token_is_max_context
[
len
(
tokens
)]
=
is_max_context
tokens
.
append
(
all_doc_tokens
[
split_token_index
])
segment_ids
.
append
(
1
)
tokens
.
append
(
"[SEP]"
)
segment_ids
.
append
(
1
)
input_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask
=
[
1
]
*
len
(
input_ids
)
# Zero-pad up to the sequence length.
#while len(input_ids) < max_seq_length:
# input_ids.append(0)
# input_mask.append(0)
# segment_ids.append(0)
#assert len(input_ids) == max_seq_length
#assert len(input_mask) == max_seq_length
#assert len(segment_ids) == max_seq_length
start_position
=
None
end_position
=
None
if
is_training
and
not
example
.
is_impossible
:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start
=
doc_span
.
start
doc_end
=
doc_span
.
start
+
doc_span
.
length
-
1
out_of_span
=
False
if
not
(
tok_start_position
>=
doc_start
and
tok_end_position
<=
doc_end
):
out_of_span
=
True
if
out_of_span
:
start_position
=
0
end_position
=
0
else
:
doc_offset
=
len
(
query_tokens
)
+
2
start_position
=
tok_start_position
-
doc_start
+
doc_offset
end_position
=
tok_end_position
-
doc_start
+
doc_offset
if
is_training
and
example
.
is_impossible
:
start_position
=
0
end_position
=
0
if
example_index
<
3
:
print
(
"*** Example ***"
)
print
(
"unique_id: %s"
%
(
unique_id
))
print
(
"example_index: %s"
%
(
example_index
))
print
(
"doc_span_index: %s"
%
(
doc_span_index
))
print
(
"tokens: %s"
%
" "
.
join
(
[
tokenization
.
printable_text
(
x
)
for
x
in
tokens
]))
print
(
"token_to_orig_map: %s"
%
" "
.
join
([
"%d:%d"
%
(
x
,
y
)
for
(
x
,
y
)
in
six
.
iteritems
(
token_to_orig_map
)
]))
print
(
"token_is_max_context: %s"
%
" "
.
join
([
"%d:%s"
%
(
x
,
y
)
for
(
x
,
y
)
in
six
.
iteritems
(
token_is_max_context
)
]))
print
(
"input_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_ids
]))
print
(
"input_mask: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_mask
]))
print
(
"segment_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
segment_ids
]))
if
is_training
and
example
.
is_impossible
:
print
(
"impossible example"
)
if
is_training
and
not
example
.
is_impossible
:
answer_text
=
" "
.
join
(
tokens
[
start_position
:(
end_position
+
1
)])
print
(
"start_position: %d"
%
(
start_position
))
print
(
"end_position: %d"
%
(
end_position
))
print
(
"answer: %s"
%
(
tokenization
.
printable_text
(
answer_text
)))
feature
=
InputFeatures
(
unique_id
=
unique_id
,
example_index
=
example_index
,
doc_span_index
=
doc_span_index
,
tokens
=
tokens
,
token_to_orig_map
=
token_to_orig_map
,
token_is_max_context
=
token_is_max_context
,
input_ids
=
input_ids
,
input_mask
=
input_mask
,
segment_ids
=
segment_ids
,
start_position
=
start_position
,
end_position
=
end_position
,
is_impossible
=
example
.
is_impossible
)
unique_id
+=
1
yield
feature
def
_improve_answer_span
(
doc_tokens
,
input_start
,
input_end
,
tokenizer
,
orig_answer_text
):
"""Returns tokenized answer spans that better match the annotated answer."""
# The SQuAD annotations are character based. We first project them to
# whitespace-tokenized words. But then after WordPiece tokenization, we can
# often find a "better match". For example:
#
# Question: What year was John Smith born?
# Context: The leader was John Smith (1895-1943).
# Answer: 1895
#
# The original whitespace-tokenized answer will be "(1895-1943).". However
# after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
# the exact answer, 1895.
#
# However, this is not always possible. Consider the following:
#
# Question: What country is the top exporter of electornics?
# Context: The Japanese electronics industry is the lagest in the world.
# Answer: Japan
#
# In this case, the annotator chose "Japan" as a character sub-span of
# the word "Japanese". Since our WordPiece tokenizer does not split
# "Japanese", we just use "Japanese" as the annotation. This is fairly rare
# in SQuAD, but does happen.
tok_answer_text
=
" "
.
join
(
tokenizer
.
tokenize
(
orig_answer_text
))
for
new_start
in
range
(
input_start
,
input_end
+
1
):
for
new_end
in
range
(
input_end
,
new_start
-
1
,
-
1
):
text_span
=
" "
.
join
(
doc_tokens
[
new_start
:(
new_end
+
1
)])
if
text_span
==
tok_answer_text
:
return
(
new_start
,
new_end
)
return
(
input_start
,
input_end
)
def
_check_is_max_context
(
doc_spans
,
cur_span_index
,
position
):
"""Check if this is the 'max context' doc span for the token."""
# Because of the sliding window approach taken to scoring documents, a single
# token can appear in multiple documents. E.g.
# Doc: the man went to the store and bought a gallon of milk
# Span A: the man went to the
# Span B: to the store and bought
# Span C: and bought a gallon of
# ...
#
# Now the word 'bought' will have two scores from spans B and C. We only
# want to consider the score with "maximum context", which we define as
# the *minimum* of its left and right context (the *sum* of left and
# right context will always be the same, of course).
#
# In the example the maximum context for 'bought' would be span C since
# it has 1 left context and 3 right context, while span B has 4 left context
# and 0 right context.
best_score
=
None
best_span_index
=
None
for
(
span_index
,
doc_span
)
in
enumerate
(
doc_spans
):
end
=
doc_span
.
start
+
doc_span
.
length
-
1
if
position
<
doc_span
.
start
:
continue
if
position
>
end
:
continue
num_left_context
=
position
-
doc_span
.
start
num_right_context
=
end
-
position
score
=
min
(
num_left_context
,
num_right_context
)
+
0.01
*
doc_span
.
length
if
best_score
is
None
or
score
>
best_score
:
best_score
=
score
best_span_index
=
span_index
return
cur_span_index
==
best_span_index
class
DataProcessor
(
object
):
def
__init__
(
self
,
vocab_path
,
do_lower_case
,
max_seq_length
,
in_tokens
,
doc_stride
,
max_query_length
):
self
.
_tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
vocab_path
,
do_lower_case
=
do_lower_case
)
self
.
_max_seq_length
=
max_seq_length
self
.
_doc_stride
=
doc_stride
self
.
_max_query_length
=
max_query_length
self
.
_in_tokens
=
in_tokens
self
.
vocab
=
self
.
_tokenizer
.
vocab
self
.
vocab_size
=
len
(
self
.
vocab
)
self
.
pad_id
=
self
.
vocab
[
"[PAD]"
]
self
.
cls_id
=
self
.
vocab
[
"[CLS]"
]
self
.
sep_id
=
self
.
vocab
[
"[SEP]"
]
self
.
mask_id
=
self
.
vocab
[
"[MASK]"
]
self
.
current_train_example
=
-
1
self
.
num_train_examples
=
-
1
self
.
current_train_epoch
=
-
1
self
.
train_examples
=
None
self
.
predict_examples
=
None
self
.
num_examples
=
{
'train'
:
-
1
,
'predict'
:
-
1
}
def
get_train_progress
(
self
):
"""Gets progress for training phase."""
return
self
.
current_train_example
,
self
.
current_train_epoch
def
get_examples
(
self
,
data_path
,
is_training
,
version_2_with_negative
=
False
):
examples
=
read_squad_examples
(
input_file
=
data_path
,
is_training
=
is_training
,
version_2_with_negative
=
version_2_with_negative
)
return
examples
def
get_num_examples
(
self
,
phase
):
if
phase
not
in
[
'train'
,
'predict'
]:
raise
ValueError
(
"Unknown phase, which should be in ['train', 'predict']."
)
return
self
.
num_examples
[
phase
]
def
get_features
(
self
,
examples
,
is_training
):
features
=
convert_examples_to_features
(
examples
=
examples
,
tokenizer
=
self
.
_tokenizer
,
max_seq_length
=
self
.
_max_seq_length
,
doc_stride
=
self
.
_doc_stride
,
max_query_length
=
self
.
_max_query_length
,
is_training
=
is_training
)
return
features
def
data_generator
(
self
,
data_path
,
batch_size
,
phase
=
'train'
,
shuffle
=
False
,
version_2_with_negative
=
False
,
epoch
=
1
):
if
phase
==
'train'
:
self
.
train_examples
=
self
.
get_examples
(
data_path
,
is_training
=
True
,
version_2_with_negative
=
version_2_with_negative
)
examples
=
self
.
train_examples
self
.
num_examples
[
'train'
]
=
len
(
self
.
train_examples
)
elif
phase
==
'predict'
:
self
.
predict_examples
=
self
.
get_examples
(
data_path
,
is_training
=
False
,
version_2_with_negative
=
version_2_with_negative
)
examples
=
self
.
predict_examples
self
.
num_examples
[
'predict'
]
=
len
(
self
.
predict_examples
)
else
:
raise
ValueError
(
"Unknown phase, which should be in ['train', 'predict']."
)
def
batch_reader
(
features
,
batch_size
,
in_tokens
):
batch
,
total_token_num
,
max_len
=
[],
0
,
0
for
(
index
,
feature
)
in
enumerate
(
features
):
if
phase
==
'train'
:
self
.
current_train_example
=
index
+
1
seq_len
=
len
(
feature
.
input_ids
)
labels
=
[
feature
.
unique_id
]
if
feature
.
start_position
is
None
else
[
feature
.
start_position
,
feature
.
end_position
]
example
=
[
feature
.
input_ids
,
feature
.
segment_ids
,
range
(
seq_len
)
]
+
labels
max_len
=
max
(
max_len
,
seq_len
)
#max_len = max(max_len, len(token_ids))
if
in_tokens
:
to_append
=
(
len
(
batch
)
+
1
)
*
max_len
<=
batch_size
else
:
to_append
=
len
(
batch
)
<
batch_size
if
to_append
:
batch
.
append
(
example
)
total_token_num
+=
seq_len
else
:
yield
batch
,
total_token_num
batch
,
total_token_num
,
max_len
=
[
example
],
seq_len
,
seq_len
if
len
(
batch
)
>
0
:
yield
batch
,
total_token_num
def
wrapper
():
for
epoch_index
in
range
(
epoch
):
if
shuffle
:
random
.
shuffle
(
examples
)
if
phase
==
'train'
:
self
.
current_train_epoch
=
epoch_index
features
=
self
.
get_features
(
examples
,
is_training
=
True
)
else
:
features
=
self
.
get_features
(
examples
,
is_training
=
False
)
for
batch_data
,
total_token_num
in
batch_reader
(
features
,
batch_size
,
self
.
_in_tokens
):
yield
prepare_batch_data
(
batch_data
,
total_token_num
,
voc_size
=-
1
,
pad_id
=
self
.
pad_id
,
cls_id
=
self
.
cls_id
,
sep_id
=
self
.
sep_id
,
mask_id
=-
1
,
return_input_mask
=
True
,
return_max_len
=
False
,
return_num_token
=
False
)
return
wrapper
def
write_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
max_answer_length
,
do_lower_case
,
output_prediction_file
,
output_nbest_file
,
output_null_log_odds_file
,
version_2_with_negative
,
null_score_diff_threshold
,
verbose
):
"""Write final predictions to the json file and log-odds of null if needed."""
print
(
"Writing predictions to: %s"
%
(
output_prediction_file
))
print
(
"Writing nbest to: %s"
%
(
output_nbest_file
))
example_index_to_features
=
collections
.
defaultdict
(
list
)
for
feature
in
all_features
:
example_index_to_features
[
feature
.
example_index
].
append
(
feature
)
unique_id_to_result
=
{}
for
result
in
all_results
:
unique_id_to_result
[
result
.
unique_id
]
=
result
_PrelimPrediction
=
collections
.
namedtuple
(
# pylint: disable=invalid-name
"PrelimPrediction"
,
[
"feature_index"
,
"start_index"
,
"end_index"
,
"start_logit"
,
"end_logit"
])
all_predictions
=
collections
.
OrderedDict
()
all_nbest_json
=
collections
.
OrderedDict
()
scores_diff_json
=
collections
.
OrderedDict
()
for
(
example_index
,
example
)
in
enumerate
(
all_examples
):
features
=
example_index_to_features
[
example_index
]
prelim_predictions
=
[]
# keep track of the minimum score of null start+end of position 0
score_null
=
1000000
# large and positive
min_null_feature_index
=
0
# the paragraph slice with min mull score
null_start_logit
=
0
# the start logit at the slice with min null score
null_end_logit
=
0
# the end logit at the slice with min null score
for
(
feature_index
,
feature
)
in
enumerate
(
features
):
result
=
unique_id_to_result
[
feature
.
unique_id
]
start_indexes
=
_get_best_indexes
(
result
.
start_logits
,
n_best_size
)
end_indexes
=
_get_best_indexes
(
result
.
end_logits
,
n_best_size
)
# if we could have irrelevant answers, get the min score of irrelevant
if
version_2_with_negative
:
feature_null_score
=
result
.
start_logits
[
0
]
+
result
.
end_logits
[
0
]
if
feature_null_score
<
score_null
:
score_null
=
feature_null_score
min_null_feature_index
=
feature_index
null_start_logit
=
result
.
start_logits
[
0
]
null_end_logit
=
result
.
end_logits
[
0
]
for
start_index
in
start_indexes
:
for
end_index
in
end_indexes
:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if
start_index
>=
len
(
feature
.
tokens
):
continue
if
end_index
>=
len
(
feature
.
tokens
):
continue
if
start_index
not
in
feature
.
token_to_orig_map
:
continue
if
end_index
not
in
feature
.
token_to_orig_map
:
continue
if
not
feature
.
token_is_max_context
.
get
(
start_index
,
False
):
continue
if
end_index
<
start_index
:
continue
length
=
end_index
-
start_index
+
1
if
length
>
max_answer_length
:
continue
prelim_predictions
.
append
(
_PrelimPrediction
(
feature_index
=
feature_index
,
start_index
=
start_index
,
end_index
=
end_index
,
start_logit
=
result
.
start_logits
[
start_index
],
end_logit
=
result
.
end_logits
[
end_index
]))
if
version_2_with_negative
:
prelim_predictions
.
append
(
_PrelimPrediction
(
feature_index
=
min_null_feature_index
,
start_index
=
0
,
end_index
=
0
,
start_logit
=
null_start_logit
,
end_logit
=
null_end_logit
))
prelim_predictions
=
sorted
(
prelim_predictions
,
key
=
lambda
x
:
(
x
.
start_logit
+
x
.
end_logit
),
reverse
=
True
)
_NbestPrediction
=
collections
.
namedtuple
(
# pylint: disable=invalid-name
"NbestPrediction"
,
[
"text"
,
"start_logit"
,
"end_logit"
])
seen_predictions
=
{}
nbest
=
[]
for
pred
in
prelim_predictions
:
if
len
(
nbest
)
>=
n_best_size
:
break
feature
=
features
[
pred
.
feature_index
]
if
pred
.
start_index
>
0
:
# this is a non-null prediction
tok_tokens
=
feature
.
tokens
[
pred
.
start_index
:(
pred
.
end_index
+
1
)]
orig_doc_start
=
feature
.
token_to_orig_map
[
pred
.
start_index
]
orig_doc_end
=
feature
.
token_to_orig_map
[
pred
.
end_index
]
orig_tokens
=
example
.
doc_tokens
[
orig_doc_start
:(
orig_doc_end
+
1
)]
tok_text
=
" "
.
join
(
tok_tokens
)
# De-tokenize WordPieces that have been split off.
tok_text
=
tok_text
.
replace
(
" ##"
,
""
)
tok_text
=
tok_text
.
replace
(
"##"
,
""
)
# Clean whitespace
tok_text
=
tok_text
.
strip
()
tok_text
=
" "
.
join
(
tok_text
.
split
())
orig_text
=
" "
.
join
(
orig_tokens
)
final_text
=
get_final_text
(
tok_text
,
orig_text
,
do_lower_case
,
verbose
)
if
final_text
in
seen_predictions
:
continue
seen_predictions
[
final_text
]
=
True
else
:
final_text
=
""
seen_predictions
[
final_text
]
=
True
nbest
.
append
(
_NbestPrediction
(
text
=
final_text
,
start_logit
=
pred
.
start_logit
,
end_logit
=
pred
.
end_logit
))
# if we didn't inlude the empty option in the n-best, inlcude it
if
version_2_with_negative
:
if
""
not
in
seen_predictions
:
nbest
.
append
(
_NbestPrediction
(
text
=
""
,
start_logit
=
null_start_logit
,
end_logit
=
null_end_logit
))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if
not
nbest
:
nbest
.
append
(
_NbestPrediction
(
text
=
"empty"
,
start_logit
=
0.0
,
end_logit
=
0.0
))
assert
len
(
nbest
)
>=
1
total_scores
=
[]
best_non_null_entry
=
None
for
entry
in
nbest
:
total_scores
.
append
(
entry
.
start_logit
+
entry
.
end_logit
)
if
not
best_non_null_entry
:
if
entry
.
text
:
best_non_null_entry
=
entry
# debug
if
best_non_null_entry
is
None
:
print
(
"Emmm..., sth wrong"
)
probs
=
_compute_softmax
(
total_scores
)
nbest_json
=
[]
for
(
i
,
entry
)
in
enumerate
(
nbest
):
output
=
collections
.
OrderedDict
()
output
[
"text"
]
=
entry
.
text
output
[
"probability"
]
=
probs
[
i
]
output
[
"start_logit"
]
=
entry
.
start_logit
output
[
"end_logit"
]
=
entry
.
end_logit
nbest_json
.
append
(
output
)
assert
len
(
nbest_json
)
>=
1
if
not
version_2_with_negative
:
all_predictions
[
example
.
qas_id
]
=
nbest_json
[
0
][
"text"
]
else
:
# predict "" iff the null score - the score of best non-null > threshold
score_diff
=
score_null
-
best_non_null_entry
.
start_logit
-
(
best_non_null_entry
.
end_logit
)
scores_diff_json
[
example
.
qas_id
]
=
score_diff
if
score_diff
>
null_score_diff_threshold
:
all_predictions
[
example
.
qas_id
]
=
""
else
:
all_predictions
[
example
.
qas_id
]
=
best_non_null_entry
.
text
all_nbest_json
[
example
.
qas_id
]
=
nbest_json
with
open
(
output_prediction_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
all_predictions
,
indent
=
4
)
+
"
\n
"
)
with
open
(
output_nbest_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
all_nbest_json
,
indent
=
4
)
+
"
\n
"
)
if
version_2_with_negative
:
with
open
(
output_null_log_odds_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
scores_diff_json
,
indent
=
4
)
+
"
\n
"
)
def
get_final_text
(
pred_text
,
orig_text
,
do_lower_case
,
verbose
):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
# now `orig_text` contains the span of our original text corresponding to the
# span that we predicted.
#
# However, `orig_text` may contain extra characters that we don't want in
# our prediction.
#
# For example, let's say:
# pred_text = steve smith
# orig_text = Steve Smith's
#
# We don't want to return `orig_text` because it contains the extra "'s".
#
# We don't want to return `pred_text` because it's already been normalized
# (the SQuAD eval script also does punctuation stripping/lower casing but
# our tokenizer does additional normalization like stripping accent
# characters).
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heruistic between
# `pred_text` and `orig_text` to get a character-to-charcter alignment. This
# can fail in certain cases in which case we just return `orig_text`.
def
_strip_spaces
(
text
):
ns_chars
=
[]
ns_to_s_map
=
collections
.
OrderedDict
()
for
(
i
,
c
)
in
enumerate
(
text
):
if
c
==
" "
:
continue
ns_to_s_map
[
len
(
ns_chars
)]
=
i
ns_chars
.
append
(
c
)
ns_text
=
""
.
join
(
ns_chars
)
return
(
ns_text
,
ns_to_s_map
)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer
=
tokenization
.
BasicTokenizer
(
do_lower_case
=
do_lower_case
)
tok_text
=
" "
.
join
(
tokenizer
.
tokenize
(
orig_text
))
start_position
=
tok_text
.
find
(
pred_text
)
if
start_position
==
-
1
:
if
verbose
:
print
(
"Unable to find text: '%s' in '%s'"
%
(
pred_text
,
orig_text
))
return
orig_text
end_position
=
start_position
+
len
(
pred_text
)
-
1
(
orig_ns_text
,
orig_ns_to_s_map
)
=
_strip_spaces
(
orig_text
)
(
tok_ns_text
,
tok_ns_to_s_map
)
=
_strip_spaces
(
tok_text
)
if
len
(
orig_ns_text
)
!=
len
(
tok_ns_text
):
if
verbose
:
print
(
"Length not equal after stripping spaces: '%s' vs '%s'"
,
orig_ns_text
,
tok_ns_text
)
return
orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map
=
{}
for
(
i
,
tok_index
)
in
six
.
iteritems
(
tok_ns_to_s_map
):
tok_s_to_ns_map
[
tok_index
]
=
i
orig_start_position
=
None
if
start_position
in
tok_s_to_ns_map
:
ns_start_position
=
tok_s_to_ns_map
[
start_position
]
if
ns_start_position
in
orig_ns_to_s_map
:
orig_start_position
=
orig_ns_to_s_map
[
ns_start_position
]
if
orig_start_position
is
None
:
if
verbose
:
print
(
"Couldn't map start position"
)
return
orig_text
orig_end_position
=
None
if
end_position
in
tok_s_to_ns_map
:
ns_end_position
=
tok_s_to_ns_map
[
end_position
]
if
ns_end_position
in
orig_ns_to_s_map
:
orig_end_position
=
orig_ns_to_s_map
[
ns_end_position
]
if
orig_end_position
is
None
:
if
verbose
:
print
(
"Couldn't map end position"
)
return
orig_text
output_text
=
orig_text
[
orig_start_position
:(
orig_end_position
+
1
)]
return
output_text
def
_get_best_indexes
(
logits
,
n_best_size
):
"""Get the n-best logits from a list."""
index_and_score
=
sorted
(
enumerate
(
logits
),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
best_indexes
=
[]
for
i
in
range
(
len
(
index_and_score
)):
if
i
>=
n_best_size
:
break
best_indexes
.
append
(
index_and_score
[
i
][
0
])
return
best_indexes
def
_compute_softmax
(
scores
):
"""Compute softmax probability over raw logits."""
if
not
scores
:
return
[]
max_score
=
None
for
score
in
scores
:
if
max_score
is
None
or
score
>
max_score
:
max_score
=
score
exp_scores
=
[]
total_sum
=
0.0
for
score
in
scores
:
x
=
math
.
exp
(
score
-
max_score
)
exp_scores
.
append
(
x
)
total_sum
+=
x
probs
=
[]
for
score
in
exp_scores
:
probs
.
append
(
score
/
total_sum
)
return
probs
if
__name__
==
'__main__'
:
train_file
=
'squad/train-v1.1.json'
vocab_file
=
'uncased_L-12_H-768_A-12/vocab.txt'
do_lower_case
=
True
tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
vocab_file
,
do_lower_case
=
do_lower_case
)
train_examples
=
read_squad_examples
(
input_file
=
train_file
,
is_training
=
True
)
print
(
"begin converting"
)
for
(
index
,
feature
)
in
enumerate
(
convert_examples_to_features
(
examples
=
train_examples
,
tokenizer
=
tokenizer
,
max_seq_length
=
384
,
doc_stride
=
128
,
max_query_length
=
64
,
is_training
=
True
,
#output_fn=train_writer.process_feature
)):
if
index
<
10
:
print
(
index
,
feature
.
input_ids
,
feature
.
input_mask
,
feature
.
segment_ids
)
#for (index, example) in enumerate(train_examples):
# if index < 5:
# print(example)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录