提交 2bbf29ae 编写于 作者: Z Zeyu Chen

add bert-reader

上级 490d9e4e
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import types
import csv
import numpy as np
import tokenization
from batching import prepare_batch_data
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def __init__(self,
data_dir,
vocab_path,
max_seq_len,
do_lower_case,
in_tokens,
random_seed=None):
self.data_dir = data_dir
self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case)
self.vocab = self.tokenizer.vocab
self.in_tokens = in_tokens
np.random.seed(random_seed)
self.current_train_example = -1
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
self.current_train_epoch = -1
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
def convert_example(self, index, example, labels, max_seq_len, tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
feature = convert_single_example(index, example, labels, max_seq_len,
tokenizer)
return feature
def generate_instance(self, feature):
"""
generate instance with given feature
Args:
feature: InputFeatures(object). A single set of features of data.
"""
input_pos = list(range(len(feature.input_ids)))
return [
feature.input_ids, feature.segment_ids, input_pos, feature.label_id
]
def generate_batch_data(self,
batch_data,
total_token_num,
voc_size=-1,
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False):
return prepare_batch_data(
batch_data,
total_token_num,
voc_size=-1,
pad_id=self.vocab["[PAD]"],
cls_id=self.vocab["[CLS]"],
sep_id=self.vocab["[SEP]"],
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
def get_num_examples(self, phase):
"""Get number of examples for train, dev or test."""
if phase not in ['train', 'dev', 'test']:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
return self.num_examples[phase]
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_train_example, self.current_train_epoch
def data_generator(self, batch_size, phase='train', epoch=1, shuffle=True):
"""
Generate data for train, dev or test.
Args:
batch_size: int. The batch size of generated data.
phase: string. The phase for which to generate data.
epoch: int. Total epoches to generate data.
shuffle: bool. Whether to shuffle examples.
"""
if phase == 'train':
examples = self.get_train_examples(self.data_dir)
self.num_examples['train'] = len(examples)
elif phase == 'dev':
examples = self.get_dev_examples(self.data_dir)
self.num_examples['dev'] = len(examples)
elif phase == 'test':
examples = self.get_test_examples(self.data_dir)
self.num_examples['test'] = len(examples)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
def instance_reader():
for epoch_index in range(epoch):
if shuffle:
np.random.shuffle(examples)
if phase == 'train':
self.current_train_epoch = epoch_index
for (index, example) in enumerate(examples):
if phase == 'train':
self.current_train_example = index + 1
feature = self.convert_example(index, example,
self.get_labels(),
self.max_seq_len,
self.tokenizer)
instance = self.generate_instance(feature)
yield instance
def batch_reader(reader, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0
for instance in reader():
token_ids, sent_ids, pos_ids, label = instance[:4]
max_len = max(max_len, len(token_ids))
if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(instance)
total_token_num += len(token_ids)
else:
yield batch, total_token_num
batch, total_token_num, max_len = [
instance
], len(token_ids), len(token_ids)
if len(batch) > 0:
yield batch, total_token_num
def wrapper():
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size, self.in_tokens):
batch_data = self.generate_batch_data(
batch_data,
total_token_num,
voc_size=-1,
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
yield batch_data
return wrapper
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_mask, segment_ids, label_id):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
def get_train_examples(self, data_dir):
"""See base class."""
self.language = "zh"
lines = self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
text_a = tokenization.convert_to_unicode(line[0])
text_b = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[2])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
self.language = "zh"
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % (i)
language = tokenization.convert_to_unicode(line[0])
if language != tokenization.convert_to_unicode(self.language):
continue
text_a = tokenization.convert_to_unicode(line[6])
text_b = tokenization.convert_to_unicode(line[7])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
self.language = "zh"
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % (i)
language = tokenization.convert_to_unicode(line[0])
if language != tokenization.convert_to_unicode(self.language):
continue
text_a = tokenization.convert_to_unicode(line[6])
text_b = tokenization.convert_to_unicode(line[7])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(
line[0]))
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
if set_type == "test":
label = "contradiction"
else:
label = tokenization.convert_to_unicode(line[-1])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3])
text_b = tokenization.convert_to_unicode(line[4])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
# Only the test set has a header
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
label = "0"
else:
text_a = tokenization.convert_to_unicode(line[3])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class ChnsenticorpProcessor(DataProcessor):
"""Processor for the Chnsenticorp data set."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_single_example_to_unicode(guid, single_example):
text_a = tokenization.convert_to_unicode(single_example[0])
text_b = tokenization.convert_to_unicode(single_example[1])
label = tokenization.convert_to_unicode(single_example[2])
return InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenizer.tokenize(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
label_id = label_map[example.label]
feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id)
return feature
def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer):
"""Convert a set of `InputExample`s to a list of `InputFeatures`."""
features = []
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
print("Writing example %d of %d" % (ex_index, len(examples)))
feature = convert_single_example(ex_index, example, label_list,
max_seq_length, tokenizer)
features.append(feature)
return features
if __name__ == '__main__':
pass
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import os
import numpy as np
import types
import gzip
import logging
import re
import six
import collections
import tokenization
import paddle
import paddle.fluid as fluid
from batching import prepare_batch_data
class DataReader(object):
def __init__(self,
data_dir,
vocab_path,
batch_size=4096,
in_tokens=True,
max_seq_len=512,
shuffle_files=True,
epoch=100,
voc_size=0,
is_test=False,
generate_neg_sample=False):
self.vocab = self.load_vocab(vocab_path)
self.data_dir = data_dir
self.batch_size = batch_size
self.in_tokens = in_tokens
self.shuffle_files = shuffle_files
self.epoch = epoch
self.current_epoch = 0
self.current_file_index = 0
self.total_file = 0
self.current_file = None
self.voc_size = voc_size
self.max_seq_len = max_seq_len
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test
self.generate_neg_sample = generate_neg_sample
if self.in_tokens:
assert self.batch_size >= self.max_seq_len, "The number of " \
"tokens in batch should not be smaller than max seq length."
if self.is_test:
self.epoch = 1
self.shuffle_files = False
def get_progress(self):
"""return current progress of traning data
"""
return self.current_epoch, self.current_file_index, self.total_file, self.current_file
def parse_line(self, line, max_seq_len=512):
""" parse one line to token_ids, sentence_ids, pos_ids, label
"""
line = line.strip().split(";")
assert len(line) == 4, "One sample must have 4 fields!"
(token_ids, sent_ids, pos_ids, label) = line
token_ids = [int(token) for token in token_ids.split(" ")]
sent_ids = [int(token) for token in sent_ids.split(" ")]
pos_ids = [int(token) for token in pos_ids.split(" ")]
assert len(token_ids) == len(sent_ids) == len(
pos_ids
), "[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids)"
label = int(label)
if len(token_ids) > max_seq_len:
return None
return [token_ids, sent_ids, pos_ids, label]
def read_file(self, file):
assert file.endswith('.gz'), "[ERROR] %s is not a gzip file" % file
file_path = self.data_dir + "/" + file
with gzip.open(file_path, "rb") as f:
for line in f:
parsed_line = self.parse_line(
line, max_seq_len=self.max_seq_len)
if parsed_line is None:
continue
yield parsed_line
def convert_to_unicode(self, text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(self, vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
for num, line in enumerate(fin):
items = self.convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
def random_pair_neg_samples(self, pos_samples):
""" randomly generate negtive samples using pos_samples
Args:
pos_samples: list of positive samples
Returns:
neg_samples: list of negtive samples
"""
np.random.shuffle(pos_samples)
num_sample = len(pos_samples)
neg_samples = []
miss_num = 0
for i in range(num_sample):
pair_index = (i + 1) % num_sample
origin_src_ids = pos_samples[i][0]
origin_sep_index = origin_src_ids.index(2)
pair_src_ids = pos_samples[pair_index][0]
pair_sep_index = pair_src_ids.index(2)
src_ids = origin_src_ids[:origin_sep_index +
1] + pair_src_ids[pair_sep_index + 1:]
if len(src_ids) >= self.max_seq_len:
miss_num += 1
continue
sent_ids = [0] * len(origin_src_ids[:origin_sep_index + 1]) + [
1
] * len(pair_src_ids[pair_sep_index + 1:])
pos_ids = list(range(len(src_ids)))
neg_sample = [src_ids, sent_ids, pos_ids, 0]
assert len(src_ids) == len(sent_ids) == len(
pos_ids
), "[ERROR]len(src_id) == lne(sent_id) == len(pos_id) must be True"
neg_samples.append(neg_sample)
return neg_samples, miss_num
def mixin_negtive_samples(self, pos_sample_generator, buffer=1000):
""" 1. generate negtive samples by randomly group sentence_1 and sentence_2 of positive samples
2. combine negtive samples and positive samples
Args:
pos_sample_generator: a generator producing a parsed positive sample, which is a list: [token_ids, sent_ids, pos_ids, 1]
Returns:
sample: one sample from shuffled positive samples and negtive samples
"""
pos_samples = []
num_total_miss = 0
pos_sample_num = 0
try:
while True:
while len(pos_samples) < buffer:
pos_sample = next(pos_sample_generator)
label = pos_sample[3]
assert label == 1, "positive sample's label must be 1"
pos_samples.append(pos_sample)
pos_sample_num += 1
neg_samples, miss_num = self.random_pair_neg_samples(
pos_samples)
num_total_miss += miss_num
samples = pos_samples + neg_samples
pos_samples = []
np.random.shuffle(samples)
for sample in samples:
yield sample
except StopIteration:
print("stopiteration: reach end of file")
if len(pos_samples) == 1:
yield pos_samples[0]
elif len(pos_samples) == 0:
yield None
else:
neg_samples, miss_num = self.random_pair_neg_samples(
pos_samples)
num_total_miss += miss_num
samples = pos_samples + neg_samples
pos_samples = []
np.random.shuffle(samples)
for sample in samples:
yield sample
print("miss_num:%d\tideal_total_sample_num:%d\tmiss_rate:%f" %
(num_total_miss, pos_sample_num * 2, num_total_miss /
(pos_sample_num * 2)))
def data_generator(self):
"""
data_generator
"""
files = os.listdir(self.data_dir)
self.total_file = len(files)
assert self.total_file > 0, "[Error] data_dir is empty"
def wrapper():
def reader():
for epoch in range(self.epoch):
self.current_epoch = epoch + 1
if self.shuffle_files:
np.random.shuffle(files)
for index, file in enumerate(files):
self.current_file_index = index + 1
self.current_file = file
sample_generator = self.read_file(file)
if not self.is_test and self.generate_neg_sample:
sample_generator = self.mixin_negtive_samples(
sample_generator)
for sample in sample_generator:
if sample is None:
continue
yield sample
def batch_reader(reader, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0
for parsed_line in reader():
token_ids, sent_ids, pos_ids, label = parsed_line
max_len = max(max_len, len(token_ids))
if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(parsed_line)
total_token_num += len(token_ids)
else:
yield batch, total_token_num
batch, total_token_num, max_len = [
parsed_line
], len(token_ids), len(token_ids)
if len(batch) > 0:
yield batch, total_token_num
for batch_data, total_token_num in batch_reader(
reader, self.batch_size, self.in_tokens):
yield prepare_batch_data(
batch_data,
total_token_num,
voc_size=self.voc_size,
pad_id=self.pad_id,
cls_id=self.cls_id,
sep_id=self.sep_id,
mask_id=self.mask_id,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return wrapper
if __name__ == "__main__":
pass
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on SQuAD 1.1 and SQuAD 2.0."""
import six
import math
import json
import random
import collections
import tokenization
from batching import prepare_batch_data
class SquadExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
self.doc_tokens = doc_tokens
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += ", question_text: %s" % (tokenization.printable_text(
self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
if self.start_position:
s += ", end_position: %d" % (self.end_position)
if self.start_position:
s += ", is_impossible: %r" % (self.is_impossible)
return s
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tokens,
token_to_orig_map,
token_is_max_context,
input_ids,
input_mask,
segment_ids,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map
self.token_is_max_context = token_is_max_context
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def read_squad_examples(input_file, is_training, version_2_with_negative=False):
"""Read a SQuAD json file into a list of SquadExample."""
with open(input_file, "r") as reader:
input_data = json.load(reader)["data"]
def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
return True
return False
examples = []
for entry in input_data:
for paragraph in entry["paragraphs"]:
paragraph_text = paragraph["context"]
doc_tokens = []
char_to_word_offset = []
prev_is_whitespace = True
for c in paragraph_text:
if is_whitespace(c):
prev_is_whitespace = True
else:
if prev_is_whitespace:
doc_tokens.append(c)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
char_to_word_offset.append(len(doc_tokens) - 1)
for qa in paragraph["qas"]:
qas_id = qa["id"]
question_text = qa["question"]
start_position = None
end_position = None
orig_answer_text = None
is_impossible = False
if is_training:
if version_2_with_negative:
is_impossible = qa["is_impossible"]
if (len(qa["answers"]) != 1) and (not is_impossible):
raise ValueError(
"For training, each question should have exactly 1 answer."
)
if not is_impossible:
answer = qa["answers"][0]
orig_answer_text = answer["text"]
answer_offset = answer["answer_start"]
answer_length = len(orig_answer_text)
start_position = char_to_word_offset[answer_offset]
end_position = char_to_word_offset[answer_offset +
answer_length - 1]
# Only add answers where the text can be exactly recovered from the
# document. If this CAN'T happen it's likely due to weird Unicode
# stuff so we will just skip the example.
#
# Note that this means for training mode, every example is NOT
# guaranteed to be preserved.
actual_text = " ".join(
doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join(
tokenization.whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
print("Could not find answer: '%s' vs. '%s'",
actual_text, cleaned_answer_text)
continue
else:
start_position = -1
end_position = -1
orig_answer_text = ""
example = SquadExample(
qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position,
is_impossible=is_impossible)
examples.append(example)
return examples
def convert_examples_to_features(
examples,
tokenizer,
max_seq_length,
doc_stride,
max_query_length,
is_training,
#output_fn
):
"""Loads a data file into a list of `InputBatch`s."""
unique_id = 1000000000
for (example_index, example) in enumerate(examples):
query_tokens = tokenizer.tokenize(example.question_text)
if len(query_tokens) > max_query_length:
query_tokens = query_tokens[0:max_query_length]
tok_to_orig_index = []
orig_to_tok_index = []
all_doc_tokens = []
for (i, token) in enumerate(example.doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
sub_tokens = tokenizer.tokenize(token)
for sub_token in sub_tokens:
tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token)
tok_start_position = None
tok_end_position = None
if is_training and example.is_impossible:
tok_start_position = -1
tok_end_position = -1
if is_training and not example.is_impossible:
tok_start_position = orig_to_tok_index[example.start_position]
if example.end_position < len(example.doc_tokens) - 1:
tok_end_position = orig_to_tok_index[example.end_position +
1] - 1
else:
tok_end_position = len(all_doc_tokens) - 1
(tok_start_position, tok_end_position) = _improve_answer_span(
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
example.orig_answer_text)
# The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
# We can have documents that are longer than the maximum sequence length.
# To deal with this we do a sliding window approach, where we take chunks
# of the up to our max length with a stride of `doc_stride`.
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
"DocSpan", ["start", "length"])
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
length = len(all_doc_tokens) - start_offset
if length > max_tokens_for_doc:
length = max_tokens_for_doc
doc_spans.append(_DocSpan(start=start_offset, length=length))
if start_offset + length == len(all_doc_tokens):
break
start_offset += min(length, doc_stride)
for (doc_span_index, doc_span) in enumerate(doc_spans):
tokens = []
token_to_orig_map = {}
token_is_max_context = {}
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in query_tokens:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for i in range(doc_span.length):
split_token_index = doc_span.start + i
token_to_orig_map[len(
tokens)] = tok_to_orig_index[split_token_index]
is_max_context = _check_is_max_context(
doc_spans, doc_span_index, split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
#while len(input_ids) < max_seq_length:
# input_ids.append(0)
# input_mask.append(0)
# segment_ids.append(0)
#assert len(input_ids) == max_seq_length
#assert len(input_mask) == max_seq_length
#assert len(segment_ids) == max_seq_length
start_position = None
end_position = None
if is_training and not example.is_impossible:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start = doc_span.start
doc_end = doc_span.start + doc_span.length - 1
out_of_span = False
if not (tok_start_position >= doc_start
and tok_end_position <= doc_end):
out_of_span = True
if out_of_span:
start_position = 0
end_position = 0
else:
doc_offset = len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
if is_training and example.is_impossible:
start_position = 0
end_position = 0
if example_index < 3:
print("*** Example ***")
print("unique_id: %s" % (unique_id))
print("example_index: %s" % (example_index))
print("doc_span_index: %s" % (doc_span_index))
print("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in tokens]))
print("token_to_orig_map: %s" % " ".join([
"%d:%d" % (x, y)
for (x, y) in six.iteritems(token_to_orig_map)
]))
print("token_is_max_context: %s" % " ".join([
"%d:%s" % (x, y)
for (x, y) in six.iteritems(token_is_max_context)
]))
print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
print("input_mask: %s" % " ".join([str(x) for x in input_mask]))
print(
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
if is_training and example.is_impossible:
print("impossible example")
if is_training and not example.is_impossible:
answer_text = " ".join(
tokens[start_position:(end_position + 1)])
print("start_position: %d" % (start_position))
print("end_position: %d" % (end_position))
print("answer: %s" %
(tokenization.printable_text(answer_text)))
feature = InputFeatures(
unique_id=unique_id,
example_index=example_index,
doc_span_index=doc_span_index,
tokens=tokens,
token_to_orig_map=token_to_orig_map,
token_is_max_context=token_is_max_context,
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
start_position=start_position,
end_position=end_position,
is_impossible=example.is_impossible)
unique_id += 1
yield feature
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
orig_answer_text):
"""Returns tokenized answer spans that better match the annotated answer."""
# The SQuAD annotations are character based. We first project them to
# whitespace-tokenized words. But then after WordPiece tokenization, we can
# often find a "better match". For example:
#
# Question: What year was John Smith born?
# Context: The leader was John Smith (1895-1943).
# Answer: 1895
#
# The original whitespace-tokenized answer will be "(1895-1943).". However
# after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
# the exact answer, 1895.
#
# However, this is not always possible. Consider the following:
#
# Question: What country is the top exporter of electornics?
# Context: The Japanese electronics industry is the lagest in the world.
# Answer: Japan
#
# In this case, the annotator chose "Japan" as a character sub-span of
# the word "Japanese". Since our WordPiece tokenizer does not split
# "Japanese", we just use "Japanese" as the annotation. This is fairly rare
# in SQuAD, but does happen.
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
for new_start in range(input_start, input_end + 1):
for new_end in range(input_end, new_start - 1, -1):
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
if text_span == tok_answer_text:
return (new_start, new_end)
return (input_start, input_end)
def _check_is_max_context(doc_spans, cur_span_index, position):
"""Check if this is the 'max context' doc span for the token."""
# Because of the sliding window approach taken to scoring documents, a single
# token can appear in multiple documents. E.g.
# Doc: the man went to the store and bought a gallon of milk
# Span A: the man went to the
# Span B: to the store and bought
# Span C: and bought a gallon of
# ...
#
# Now the word 'bought' will have two scores from spans B and C. We only
# want to consider the score with "maximum context", which we define as
# the *minimum* of its left and right context (the *sum* of left and
# right context will always be the same, of course).
#
# In the example the maximum context for 'bought' would be span C since
# it has 1 left context and 3 right context, while span B has 4 left context
# and 0 right context.
best_score = None
best_span_index = None
for (span_index, doc_span) in enumerate(doc_spans):
end = doc_span.start + doc_span.length - 1
if position < doc_span.start:
continue
if position > end:
continue
num_left_context = position - doc_span.start
num_right_context = end - position
score = min(num_left_context,
num_right_context) + 0.01 * doc_span.length
if best_score is None or score > best_score:
best_score = score
best_span_index = span_index
return cur_span_index == best_span_index
class DataProcessor(object):
def __init__(self, vocab_path, do_lower_case, max_seq_length, in_tokens,
doc_stride, max_query_length):
self._tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case)
self._max_seq_length = max_seq_length
self._doc_stride = doc_stride
self._max_query_length = max_query_length
self._in_tokens = in_tokens
self.vocab = self._tokenizer.vocab
self.vocab_size = len(self.vocab)
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.mask_id = self.vocab["[MASK]"]
self.current_train_example = -1
self.num_train_examples = -1
self.current_train_epoch = -1
self.train_examples = None
self.predict_examples = None
self.num_examples = {'train': -1, 'predict': -1}
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_train_example, self.current_train_epoch
def get_examples(self,
data_path,
is_training,
version_2_with_negative=False):
examples = read_squad_examples(
input_file=data_path,
is_training=is_training,
version_2_with_negative=version_2_with_negative)
return examples
def get_num_examples(self, phase):
if phase not in ['train', 'predict']:
raise ValueError(
"Unknown phase, which should be in ['train', 'predict'].")
return self.num_examples[phase]
def get_features(self, examples, is_training):
features = convert_examples_to_features(
examples=examples,
tokenizer=self._tokenizer,
max_seq_length=self._max_seq_length,
doc_stride=self._doc_stride,
max_query_length=self._max_query_length,
is_training=is_training)
return features
def data_generator(self,
data_path,
batch_size,
phase='train',
shuffle=False,
version_2_with_negative=False,
epoch=1):
if phase == 'train':
self.train_examples = self.get_examples(
data_path,
is_training=True,
version_2_with_negative=version_2_with_negative)
examples = self.train_examples
self.num_examples['train'] = len(self.train_examples)
elif phase == 'predict':
self.predict_examples = self.get_examples(
data_path,
is_training=False,
version_2_with_negative=version_2_with_negative)
examples = self.predict_examples
self.num_examples['predict'] = len(self.predict_examples)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'predict'].")
def batch_reader(features, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0
for (index, feature) in enumerate(features):
if phase == 'train':
self.current_train_example = index + 1
seq_len = len(feature.input_ids)
labels = [feature.unique_id
] if feature.start_position is None else [
feature.start_position, feature.end_position
]
example = [
feature.input_ids, feature.segment_ids,
range(seq_len)
] + labels
max_len = max(max_len, seq_len)
#max_len = max(max_len, len(token_ids))
if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(example)
total_token_num += seq_len
else:
yield batch, total_token_num
batch, total_token_num, max_len = [example
], seq_len, seq_len
if len(batch) > 0:
yield batch, total_token_num
def wrapper():
for epoch_index in range(epoch):
if shuffle:
random.shuffle(examples)
if phase == 'train':
self.current_train_epoch = epoch_index
features = self.get_features(examples, is_training=True)
else:
features = self.get_features(examples, is_training=False)
for batch_data, total_token_num in batch_reader(
features, batch_size, self._in_tokens):
yield prepare_batch_data(
batch_data,
total_token_num,
voc_size=-1,
pad_id=self.pad_id,
cls_id=self.cls_id,
sep_id=self.sep_id,
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return wrapper
def write_predictions(all_examples, all_features, all_results, n_best_size,
max_answer_length, do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file,
version_2_with_negative, null_score_diff_threshold,
verbose):
"""Write final predictions to the json file and log-odds of null if needed."""
print("Writing predictions to: %s" % (output_prediction_file))
print("Writing nbest to: %s" % (output_nbest_file))
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
example_index_to_features[feature.example_index].append(feature)
unique_id_to_result = {}
for result in all_results:
unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction", [
"feature_index", "start_index", "end_index", "start_logit",
"end_logit"
])
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict()
for (example_index, example) in enumerate(all_examples):
features = example_index_to_features[example_index]
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
min_null_feature_index = 0 # the paragraph slice with min mull score
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[
0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes:
for end_index in end_indexes:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(feature.tokens):
continue
if end_index >= len(feature.tokens):
continue
if start_index not in feature.token_to_orig_map:
continue
if end_index not in feature.token_to_orig_map:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
if version_2_with_negative:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit))
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_logit", "end_logit"])
seen_predictions = {}
nbest = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(
pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(
orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, do_lower_case,
verbose)
if final_text in seen_predictions:
continue
seen_predictions[final_text] = True
else:
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't inlude the empty option in the n-best, inlcude it
if version_2_with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
# debug
if best_non_null_entry is None:
print("Emmm..., sth wrong")
probs = _compute_softmax(total_scores)
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
nbest_json.append(output)
assert len(nbest_json) >= 1
if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
with open(output_nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
with open(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case, verbose):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
# now `orig_text` contains the span of our original text corresponding to the
# span that we predicted.
#
# However, `orig_text` may contain extra characters that we don't want in
# our prediction.
#
# For example, let's say:
# pred_text = steve smith
# orig_text = Steve Smith's
#
# We don't want to return `orig_text` because it contains the extra "'s".
#
# We don't want to return `pred_text` because it's already been normalized
# (the SQuAD eval script also does punctuation stripping/lower casing but
# our tokenizer does additional normalization like stripping accent
# characters).
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heruistic between
# `pred_text` and `orig_text` to get a character-to-charcter alignment. This
# can fail in certain cases in which case we just return `orig_text`.
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
if verbose:
print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
if verbose:
print("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text)
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
if verbose:
print("Couldn't map start position")
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
if verbose:
print("Couldn't map end position")
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
def _get_best_indexes(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(
enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
for i in range(len(index_and_score)):
if i >= n_best_size:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
if __name__ == '__main__':
train_file = 'squad/train-v1.1.json'
vocab_file = 'uncased_L-12_H-768_A-12/vocab.txt'
do_lower_case = True
tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file, do_lower_case=do_lower_case)
train_examples = read_squad_examples(
input_file=train_file, is_training=True)
print("begin converting")
for (index, feature) in enumerate(
convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=384,
doc_stride=128,
max_query_length=64,
is_training=True,
#output_fn=train_writer.process_feature
)):
if index < 10:
print(index, feature.input_ids, feature.input_mask,
feature.segment_ids)
#for (index, example) in enumerate(train_examples):
# if index < 5:
# print(example)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册