diff --git a/paddleslim/nas/darts/search_space/conv_bert/cls.py b/paddleslim/nas/darts/search_space/conv_bert/cls.py
index 80c000320f97512b6d75178f255394a32152c0bf..f5a4ad6eb24ba747043471cc2cd7cbfcf16659e5 100755
--- a/paddleslim/nas/darts/search_space/conv_bert/cls.py
+++ b/paddleslim/nas/darts/search_space/conv_bert/cls.py
@@ -83,23 +83,34 @@ class AdaBERTClassifier(Layer):
         sentence_ids = data_ids[2]
         input_mask = data_ids[3]
         labels = data_ids[4]
 
-        enc_outputs, next_sent_feats = self.student(src_ids, position_ids,
-                                                    sentence_ids)
+        flops = []
+        model_size = []
+        enc_outputs, next_sent_feats, k_i = self.student(
+            src_ids,
+            position_ids,
+            sentence_ids,
+            flops=flops,
+            model_size=model_size)
 
         self.teacher.eval()
-        total_loss, logits, losses, accuracys, num_seqs = self.teacher(
+        total_loss, t_logits, t_losses, accuracys, num_seqs = self.teacher(
             data_ids)
 
+        # define kd loss
         kd_losses = []
-        for t_logits, t_loss, s_sent_feat, fc in zip(
-                logits, losses, next_sent_feats, self.cls_fc):
+        for i in range(len(next_sent_feats)):
+            j = int(np.ceil(i * (len(next_sent_feats) / len(t_logits))))
+            t_logit = t_logits[j]
+            t_loss = t_losses[j]
+            s_sent_feat = next_sent_feats[i]
+            fc = self.cls_fc[i]
             s_sent_feat = fluid.layers.dropout(
                 x=s_sent_feat,
                 dropout_prob=0.1,
                 dropout_implementation="upscale_in_train")
             s_logits = fc(s_sent_feat)
-            t_probs = fluid.layers.softmax(t_logits)
+            t_probs = fluid.layers.softmax(t_logit)
             s_probs = fluid.layers.softmax(s_logits)
             t_probs.stop_gradient = False
             kd_loss = t_probs * fluid.layers.log(s_probs / T)
@@ -110,9 +121,16 @@ class AdaBERTClassifier(Layer):
 
         kd_loss = fluid.layers.sum(kd_losses)
 
+        # define ce loss
         ce_loss = fluid.layers.cross_entropy(s_probs, labels)
-        ce_loss = fluid.layers.mean(x=ce_loss)
+        ce_loss = fluid.layers.mean(x=ce_loss) * k_i
 
-        e_loss = 1  # to be done
+        # define e loss
+        model_size = fluid.layers.sum(model_size)
+        flops = fluid.layers.sum(flops)
+        e_loss = (len(next_sent_feats) * k_i / self._n_layer) * (
+            flops + model_size)
+
+        # define total loss
         loss = (1 - gamma) * ce_loss - gamma * kd_loss + beta * e_loss
 
         return loss
diff --git a/paddleslim/nas/darts/search_space/conv_bert/model/bert.py b/paddleslim/nas/darts/search_space/conv_bert/model/bert.py
index 377fcf4d8c9728a60254f6bd75982d24ac2c1e20..d8c5ce680a2ea21886d6f03a2ddb37a97144c51a 100644
--- a/paddleslim/nas/darts/search_space/conv_bert/model/bert.py
+++ b/paddleslim/nas/darts/search_space/conv_bert/model/bert.py
@@ -85,7 +85,12 @@ class BertModelLayer(Layer):
     def arch_parameters(self):
         return [self._encoder.alphas]
 
-    def forward(self, src_ids, position_ids, sentence_ids):
+    def forward(self,
+                src_ids,
+                position_ids,
+                sentence_ids,
+                flops=[],
+                model_size=[]):
         """
         forward
         """
@@ -96,7 +101,8 @@ class BertModelLayer(Layer):
 
         emb_out = src_emb + pos_emb
         emb_out = emb_out + sent_emb
-        enc_outputs = self._encoder(emb_out)
+        enc_outputs, k_i = self._encoder(
+            emb_out, flops=flops, model_size=model_size)
 
         if not self.return_pooled_out:
             return enc_outputs
@@ -109,4 +115,4 @@ class BertModelLayer(Layer):
                 next_sent_feat, shape=[-1, self._emb_size])
             next_sent_feats.append(next_sent_feat)
 
-        return enc_outputs, next_sent_feats
+        return enc_outputs, next_sent_feats, k_i
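The new `forward` above combines the three AdaBERT loss terms: a layer-wise knowledge-distillation term that matches each student probe head to a teacher layer, a cross-entropy term scaled by the sampled depth gate `k_i`, and an efficiency term built from the FLOPs and parameter counts accumulated during the student forward pass. The sketch below is a framework-free illustration of that combination only; the `softmax` helper, the toy defaults for `gamma`, `beta` and `T`, and all shapes are assumptions for illustration, not the values used by this PR.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def adabert_loss(s_logits_per_layer, t_logits_per_layer, labels,
                 flops, model_size, k_i, n_layer,
                 gamma=0.8, beta=4.0, T=1.0):
    """Toy version of the loss composition in AdaBERTClassifier.forward."""
    n_student = len(s_logits_per_layer)
    n_teacher = len(t_logits_per_layer)

    # knowledge distillation: map student probe i onto a teacher layer j
    kd_terms = []
    for i, s_logits in enumerate(s_logits_per_layer):
        j = int(np.ceil(i * (n_student / n_teacher)))  # same mapping as the diff
        t_probs = softmax(t_logits_per_layer[j])
        s_probs = softmax(s_logits)
        kd_terms.append(np.sum(t_probs * np.log(s_probs / T)))
    kd_loss = np.sum(kd_terms)

    # cross entropy on the last probe, scaled by the sampled depth gate k_i
    last_probs = softmax(s_logits_per_layer[-1])
    ce_loss = -np.mean(np.log(last_probs[np.arange(len(labels)), labels])) * k_i

    # efficiency: accumulated FLOPs and parameter counts of the sampled ops
    e_loss = (n_student * k_i / n_layer) * (np.sum(flops) + np.sum(model_size))

    return (1 - gamma) * ce_loss - gamma * kd_loss + beta * e_loss


# toy usage: 2 student probes, 4 teacher layers, batch of 3, 2 classes
rng = np.random.RandomState(0)
student = [rng.randn(3, 2) for _ in range(2)]
teacher = [rng.randn(3, 2) for _ in range(4)]
print(adabert_loss(student, teacher, labels=np.array([0, 1, 1]),
                   flops=[1e6, 2e6], model_size=[3e3, 5e3],
                   k_i=1.0, n_layer=2))
```

In the PR the same combination is expressed with `fluid.layers` ops so the result stays differentiable with respect to the architecture parameters.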
diff --git a/paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py b/paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py
index 760d36a45d7cc77a24ea139dda49727b0824be76..9de8f7bc5479acb0a6af12f23f5ce84ece091e78 100644
--- a/paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py
+++ b/paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py
@@ -29,6 +29,33 @@ PRIMITIVES = [
     'dil_conv_7', 'avg_pool_3', 'max_pool_3', 'none', 'skip_connect'
 ]
 
+input_size = 128 * 768
+
+FLOPs = {
+    'std_conv_3': input_size * 3 * 1,
+    'std_conv_5': input_size * 5 * 1,
+    'std_conv_7': input_size * 7 * 1,
+    'dil_conv_3': input_size * 3 * 1,
+    'dil_conv_5': input_size * 5 * 1,
+    'dil_conv_7': input_size * 7 * 1,
+    'avg_pool_3': input_size * 3 * 1,
+    'max_pool_3': input_size * 3 * 1,
+    'none': 0,
+    'skip_connect': 0,
+}
+
+ModelSize = {
+    'std_conv_3': 3 * 1,
+    'std_conv_5': 5 * 1,
+    'std_conv_7': 7 * 1,
+    'dil_conv_3': 3 * 1,
+    'dil_conv_5': 5 * 1,
+    'dil_conv_7': 7 * 1,
+    'avg_pool_3': 0,
+    'max_pool_3': 0,
+    'none': 0,
+    'skip_connect': 0,
+}
 
 OPS = {
     'std_conv_3': lambda : ConvBN(1, 1, filter_size=3, dilation=1),
@@ -50,9 +77,11 @@ class MixedOp(fluid.dygraph.Layer):
         ops = [OPS[primitive]() for primitive in PRIMITIVES]
         self._ops = fluid.dygraph.LayerList(ops)
 
-    def forward(self, x, weights):
+    def forward(self, x, weights, flops=[], model_size=[]):
         for i in range(len(self._ops)):
             if weights[i] != 0:
+                flops.append(FLOPs[PRIMITIVES[i]] * weights[i])
+                model_size.append(ModelSize[PRIMITIVES[i]] * weights[i])
                 return self._ops[i](x) * weights[i]
 
 
@@ -132,13 +161,16 @@ class Cell(fluid.dygraph.Layer):
             ops.append(op)
         self._ops = fluid.dygraph.LayerList(ops)
 
-    def forward(self, s0, s1, weights, weights2=None):
+    def forward(self, s0, s1, weights, weights2=None, flops=[], model_size=[]):
         states = [s0, s1]
         offset = 0
         for i in range(self._steps):
             s = fluid.layers.sums([
-                self._ops[offset + j](h, weights[offset + j])
+                self._ops[offset + j](h,
+                                      weights[offset + j],
+                                      flops=flops,
+                                      model_size=model_size)
                 for j, h in enumerate(states)
             ])
             offset += len(states)
 
@@ -173,7 +205,13 @@ class EncoderLayer(Layer):
             default_initializer=NormalInitializer(
                 loc=0.0, scale=1e-3))
 
-    def forward(self, enc_input):
+        self.k = fluid.layers.create_parameter(
+            shape=[1, self._n_layer],
+            dtype="float32",
+            default_initializer=NormalInitializer(
+                loc=0.0, scale=1e-3))
+
+    def forward(self, enc_input, flops=[], model_size=[]):
         """
         forward
         :param enc_input:
@@ -184,12 +222,20 @@
             [-1, 1, enc_input.shape[1], self._d_model])
 
         alphas = gumbel_softmax(self.alphas)
+        k = gumbel_softmax(self.k)
 
         outputs = []
         s0 = s1 = tmp
-        for i, cell in enumerate(self._cells):
-            s0, s1 = s1, cell(s0, s1, alphas)
+        for i in range(self._n_layer):
+            s0, s1 = s1, self._cells[i](s0,
+                                        s1,
+                                        alphas,
+                                        flops=flops,
+                                        model_size=model_size)
             enc_output = fluid.layers.reshape(
                 s1, [-1, enc_input.shape[1], self._d_model])
             outputs.append(enc_output)
-        return outputs
+            if k[i] != 0:
+                outputs[-1] = outputs[-1] * k[i]
+                break
+        return outputs, k[i]
diff --git a/paddleslim/teachers/bert/reader/pretraining.py b/paddleslim/teachers/bert/reader/pretraining.py
deleted file mode 100644
index c21a43d33caedd9a01c02dacbedd01a16e1eec9f..0000000000000000000000000000000000000000
--- a/paddleslim/teachers/bert/reader/pretraining.py
+++ /dev/null
@@ -1,289 +0,0 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function -from __future__ import division - -import os -import numpy as np -import types -import gzip -import logging -import re -import six -import collections -import tokenization - -import paddle -import paddle.fluid as fluid - -from batching import prepare_batch_data - - -class DataReader(object): - def __init__(self, - data_dir, - vocab_path, - batch_size=4096, - in_tokens=True, - max_seq_len=512, - shuffle_files=True, - epoch=100, - voc_size=0, - is_test=False, - generate_neg_sample=False): - - self.vocab = self.load_vocab(vocab_path) - self.data_dir = data_dir - self.batch_size = batch_size - self.in_tokens = in_tokens - self.shuffle_files = shuffle_files - self.epoch = epoch - self.current_epoch = 0 - self.current_file_index = 0 - self.total_file = 0 - self.current_file = None - self.voc_size = voc_size - self.max_seq_len = max_seq_len - self.pad_id = self.vocab["[PAD]"] - self.cls_id = self.vocab["[CLS]"] - self.sep_id = self.vocab["[SEP]"] - self.mask_id = self.vocab["[MASK]"] - self.is_test = is_test - self.generate_neg_sample = generate_neg_sample - if self.in_tokens: - assert self.batch_size >= self.max_seq_len, "The number of " \ - "tokens in batch should not be smaller than max seq length." - - if self.is_test: - self.epoch = 1 - self.shuffle_files = False - - def get_progress(self): - """return current progress of traning data - """ - return self.current_epoch, self.current_file_index, self.total_file, self.current_file - - def parse_line(self, line, max_seq_len=512): - """ parse one line to token_ids, sentence_ids, pos_ids, label - """ - line = line.strip().decode().split(";") - assert len(line) == 4, "One sample must have 4 fields!" 
- (token_ids, sent_ids, pos_ids, label) = line - token_ids = [int(token) for token in token_ids.split(" ")] - sent_ids = [int(token) for token in sent_ids.split(" ")] - pos_ids = [int(token) for token in pos_ids.split(" ")] - assert len(token_ids) == len(sent_ids) == len( - pos_ids - ), "[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids)" - label = int(label) - if len(token_ids) > max_seq_len: - return None - return [token_ids, sent_ids, pos_ids, label] - - def read_file(self, file): - assert file.endswith('.gz'), "[ERROR] %s is not a gzip file" % file - file_path = self.data_dir + "/" + file - with gzip.open(file_path, "rb") as f: - for line in f: - parsed_line = self.parse_line( - line, max_seq_len=self.max_seq_len) - if parsed_line is None: - continue - yield parsed_line - - def convert_to_unicode(self, text): - """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text.decode("utf-8", "ignore") - elif isinstance(text, unicode): - return text - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - else: - raise ValueError("Not running on Python2 or Python 3?") - - def load_vocab(self, vocab_file): - """Loads a vocabulary file into a dictionary.""" - vocab = collections.OrderedDict() - fin = open(vocab_file) - for num, line in enumerate(fin): - items = self.convert_to_unicode(line.strip()).split("\t") - if len(items) > 2: - break - token = items[0] - index = items[1] if len(items) == 2 else num - token = token.strip() - vocab[token] = int(index) - return vocab - - def random_pair_neg_samples(self, pos_samples): - """ randomly generate negtive samples using pos_samples - - Args: - pos_samples: list of positive samples - - Returns: - neg_samples: list of negtive samples - """ - np.random.shuffle(pos_samples) - num_sample = len(pos_samples) - neg_samples = [] - miss_num = 0 - - for i in range(num_sample): - pair_index = (i + 1) % num_sample - origin_src_ids = pos_samples[i][0] - origin_sep_index = origin_src_ids.index(2) - pair_src_ids = pos_samples[pair_index][0] - pair_sep_index = pair_src_ids.index(2) - - src_ids = origin_src_ids[:origin_sep_index + 1] + pair_src_ids[ - pair_sep_index + 1:] - if len(src_ids) >= self.max_seq_len: - miss_num += 1 - continue - sent_ids = [0] * len(origin_src_ids[:origin_sep_index + 1]) + [ - 1 - ] * len(pair_src_ids[pair_sep_index + 1:]) - pos_ids = list(range(len(src_ids))) - neg_sample = [src_ids, sent_ids, pos_ids, 0] - assert len(src_ids) == len(sent_ids) == len( - pos_ids - ), "[ERROR]len(src_id) == lne(sent_id) == len(pos_id) must be True" - neg_samples.append(neg_sample) - return neg_samples, miss_num - - def mixin_negtive_samples(self, pos_sample_generator, buffer=1000): - """ 1. generate negtive samples by randomly group sentence_1 and sentence_2 of positive samples - 2. 
combine negtive samples and positive samples - - Args: - pos_sample_generator: a generator producing a parsed positive sample, which is a list: [token_ids, sent_ids, pos_ids, 1] - - Returns: - sample: one sample from shuffled positive samples and negtive samples - """ - pos_samples = [] - num_total_miss = 0 - pos_sample_num = 0 - try: - while True: - while len(pos_samples) < buffer: - pos_sample = next(pos_sample_generator) - label = pos_sample[3] - assert label == 1, "positive sample's label must be 1" - pos_samples.append(pos_sample) - pos_sample_num += 1 - - neg_samples, miss_num = self.random_pair_neg_samples( - pos_samples) - num_total_miss += miss_num - samples = pos_samples + neg_samples - pos_samples = [] - np.random.shuffle(samples) - for sample in samples: - yield sample - except StopIteration: - print("stopiteration: reach end of file") - if len(pos_samples) == 1: - yield pos_samples[0] - elif len(pos_samples) == 0: - yield None - else: - neg_samples, miss_num = self.random_pair_neg_samples( - pos_samples) - num_total_miss += miss_num - samples = pos_samples + neg_samples - pos_samples = [] - np.random.shuffle(samples) - for sample in samples: - yield sample - print("miss_num:%d\tideal_total_sample_num:%d\tmiss_rate:%f" % - (num_total_miss, pos_sample_num * 2, - num_total_miss / (pos_sample_num * 2))) - - def data_generator(self): - """ - data_generator - """ - files = os.listdir(self.data_dir) - self.total_file = len(files) - assert self.total_file > 0, "[Error] data_dir is empty" - - def wrapper(): - def reader(): - for epoch in range(self.epoch): - self.current_epoch = epoch + 1 - if self.shuffle_files: - np.random.shuffle(files) - for index, file in enumerate(files): - self.current_file_index = index + 1 - self.current_file = file - sample_generator = self.read_file(file) - if not self.is_test and self.generate_neg_sample: - sample_generator = self.mixin_negtive_samples( - sample_generator) - for sample in sample_generator: - if sample is None: - continue - yield sample - - def batch_reader(reader, batch_size, in_tokens): - batch, total_token_num, max_len = [], 0, 0 - for parsed_line in reader(): - token_ids, sent_ids, pos_ids, label = parsed_line - max_len = max(max_len, len(token_ids)) - if in_tokens: - to_append = (len(batch) + 1) * max_len <= batch_size - else: - to_append = len(batch) < batch_size - if to_append: - batch.append(parsed_line) - total_token_num += len(token_ids) - else: - yield batch, total_token_num - batch, total_token_num, max_len = [parsed_line], len( - token_ids), len(token_ids) - - if len(batch) > 0: - yield batch, total_token_num - - for batch_data, total_token_num in batch_reader( - reader, self.batch_size, self.in_tokens): - yield prepare_batch_data( - batch_data, - total_token_num, - voc_size=self.voc_size, - pad_id=self.pad_id, - cls_id=self.cls_id, - sep_id=self.sep_id, - mask_id=self.mask_id, - return_input_mask=True, - return_max_len=False, - return_num_token=False) - - return wrapper - - -if __name__ == "__main__": - pass diff --git a/paddleslim/teachers/bert/reader/squad.py b/paddleslim/teachers/bert/reader/squad.py deleted file mode 100644 index 651c46f966d228e626cdd25e1fb73809801716d0..0000000000000000000000000000000000000000 --- a/paddleslim/teachers/bert/reader/squad.py +++ /dev/null @@ -1,935 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Run BERT on SQuAD 1.1 and SQuAD 2.0.""" - -import six -import math -import json -import random -import collections -import tokenization -from batching import prepare_batch_data - - -class SquadExample(object): - """A single training/test example for simple sequence classification. - - For examples without an answer, the start and end position are -1. - """ - - def __init__(self, - qas_id, - question_text, - doc_tokens, - orig_answer_text=None, - start_position=None, - end_position=None, - is_impossible=False): - self.qas_id = qas_id - self.question_text = question_text - self.doc_tokens = doc_tokens - self.orig_answer_text = orig_answer_text - self.start_position = start_position - self.end_position = end_position - self.is_impossible = is_impossible - - def __str__(self): - return self.__repr__() - - def __repr__(self): - s = "" - s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) - s += ", question_text: %s" % ( - tokenization.printable_text(self.question_text)) - s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) - if self.start_position: - s += ", start_position: %d" % (self.start_position) - if self.start_position: - s += ", end_position: %d" % (self.end_position) - if self.start_position: - s += ", is_impossible: %r" % (self.is_impossible) - return s - - -class InputFeatures(object): - """A single set of features of data.""" - - def __init__(self, - unique_id, - example_index, - doc_span_index, - tokens, - token_to_orig_map, - token_is_max_context, - input_ids, - input_mask, - segment_ids, - start_position=None, - end_position=None, - is_impossible=None): - self.unique_id = unique_id - self.example_index = example_index - self.doc_span_index = doc_span_index - self.tokens = tokens - self.token_to_orig_map = token_to_orig_map - self.token_is_max_context = token_is_max_context - self.input_ids = input_ids - self.input_mask = input_mask - self.segment_ids = segment_ids - self.start_position = start_position - self.end_position = end_position - self.is_impossible = is_impossible - - -def read_squad_examples(input_file, is_training, - version_2_with_negative=False): - """Read a SQuAD json file into a list of SquadExample.""" - with open(input_file, "r") as reader: - input_data = json.load(reader)["data"] - - def is_whitespace(c): - if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: - return True - return False - - examples = [] - for entry in input_data: - for paragraph in entry["paragraphs"]: - paragraph_text = paragraph["context"] - doc_tokens = [] - char_to_word_offset = [] - prev_is_whitespace = True - for c in paragraph_text: - if is_whitespace(c): - prev_is_whitespace = True - else: - if prev_is_whitespace: - doc_tokens.append(c) - else: - doc_tokens[-1] += c - prev_is_whitespace = False - char_to_word_offset.append(len(doc_tokens) - 1) - - for qa in paragraph["qas"]: - qas_id = qa["id"] - question_text = qa["question"] - start_position = None - end_position = None - orig_answer_text = None - is_impossible = False - if is_training: - - if version_2_with_negative: - is_impossible = qa["is_impossible"] - if (len(qa["answers"]) != 1) 
and (not is_impossible): - raise ValueError( - "For training, each question should have exactly 1 answer." - ) - if not is_impossible: - answer = qa["answers"][0] - orig_answer_text = answer["text"] - answer_offset = answer["answer_start"] - answer_length = len(orig_answer_text) - start_position = char_to_word_offset[answer_offset] - end_position = char_to_word_offset[answer_offset + - answer_length - 1] - # Only add answers where the text can be exactly recovered from the - # document. If this CAN'T happen it's likely due to weird Unicode - # stuff so we will just skip the example. - # - # Note that this means for training mode, every example is NOT - # guaranteed to be preserved. - actual_text = " ".join(doc_tokens[start_position:( - end_position + 1)]) - cleaned_answer_text = " ".join( - tokenization.whitespace_tokenize(orig_answer_text)) - if actual_text.find(cleaned_answer_text) == -1: - print("Could not find answer: '%s' vs. '%s'", - actual_text, cleaned_answer_text) - continue - else: - start_position = -1 - end_position = -1 - orig_answer_text = "" - - example = SquadExample( - qas_id=qas_id, - question_text=question_text, - doc_tokens=doc_tokens, - orig_answer_text=orig_answer_text, - start_position=start_position, - end_position=end_position, - is_impossible=is_impossible) - examples.append(example) - - return examples - - -def convert_examples_to_features( - examples, - tokenizer, - max_seq_length, - doc_stride, - max_query_length, - is_training, - #output_fn -): - """Loads a data file into a list of `InputBatch`s.""" - - unique_id = 1000000000 - - for (example_index, example) in enumerate(examples): - query_tokens = tokenizer.tokenize(example.question_text) - - if len(query_tokens) > max_query_length: - query_tokens = query_tokens[0:max_query_length] - - tok_to_orig_index = [] - orig_to_tok_index = [] - all_doc_tokens = [] - for (i, token) in enumerate(example.doc_tokens): - orig_to_tok_index.append(len(all_doc_tokens)) - sub_tokens = tokenizer.tokenize(token) - for sub_token in sub_tokens: - tok_to_orig_index.append(i) - all_doc_tokens.append(sub_token) - - tok_start_position = None - tok_end_position = None - if is_training and example.is_impossible: - tok_start_position = -1 - tok_end_position = -1 - if is_training and not example.is_impossible: - tok_start_position = orig_to_tok_index[example.start_position] - if example.end_position < len(example.doc_tokens) - 1: - tok_end_position = orig_to_tok_index[example.end_position + - 1] - 1 - else: - tok_end_position = len(all_doc_tokens) - 1 - (tok_start_position, tok_end_position) = _improve_answer_span( - all_doc_tokens, tok_start_position, tok_end_position, - tokenizer, example.orig_answer_text) - - # The -3 accounts for [CLS], [SEP] and [SEP] - max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 - - # We can have documents that are longer than the maximum sequence length. - # To deal with this we do a sliding window approach, where we take chunks - # of the up to our max length with a stride of `doc_stride`. 
- _DocSpan = collections.namedtuple( # pylint: disable=invalid-name - "DocSpan", ["start", "length"]) - doc_spans = [] - start_offset = 0 - while start_offset < len(all_doc_tokens): - length = len(all_doc_tokens) - start_offset - if length > max_tokens_for_doc: - length = max_tokens_for_doc - doc_spans.append(_DocSpan(start=start_offset, length=length)) - if start_offset + length == len(all_doc_tokens): - break - start_offset += min(length, doc_stride) - - for (doc_span_index, doc_span) in enumerate(doc_spans): - tokens = [] - token_to_orig_map = {} - token_is_max_context = {} - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for token in query_tokens: - tokens.append(token) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) - - for i in range(doc_span.length): - split_token_index = doc_span.start + i - token_to_orig_map[len(tokens)] = tok_to_orig_index[ - split_token_index] - - is_max_context = _check_is_max_context( - doc_spans, doc_span_index, split_token_index) - token_is_max_context[len(tokens)] = is_max_context - tokens.append(all_doc_tokens[split_token_index]) - segment_ids.append(1) - tokens.append("[SEP]") - segment_ids.append(1) - - input_ids = tokenizer.convert_tokens_to_ids(tokens) - - # The mask has 1 for real tokens and 0 for padding tokens. Only real - # tokens are attended to. - input_mask = [1] * len(input_ids) - - # Zero-pad up to the sequence length. - #while len(input_ids) < max_seq_length: - # input_ids.append(0) - # input_mask.append(0) - # segment_ids.append(0) - - #assert len(input_ids) == max_seq_length - #assert len(input_mask) == max_seq_length - #assert len(segment_ids) == max_seq_length - - start_position = None - end_position = None - if is_training and not example.is_impossible: - # For training, if our document chunk does not contain an annotation - # we throw it out, since there is nothing to predict. 
- doc_start = doc_span.start - doc_end = doc_span.start + doc_span.length - 1 - out_of_span = False - if not (tok_start_position >= doc_start and - tok_end_position <= doc_end): - out_of_span = True - if out_of_span: - start_position = 0 - end_position = 0 - else: - doc_offset = len(query_tokens) + 2 - start_position = tok_start_position - doc_start + doc_offset - end_position = tok_end_position - doc_start + doc_offset - - if is_training and example.is_impossible: - start_position = 0 - end_position = 0 - """ - if example_index < 3: - print("*** Example ***") - print("unique_id: %s" % (unique_id)) - print("example_index: %s" % (example_index)) - print("doc_span_index: %s" % (doc_span_index)) - print("tokens: %s" % " ".join( - [tokenization.printable_text(x) for x in tokens])) - print("token_to_orig_map: %s" % " ".join([ - "%d:%d" % (x, y) - for (x, y) in six.iteritems(token_to_orig_map) - ])) - print("token_is_max_context: %s" % " ".join([ - "%d:%s" % (x, y) - for (x, y) in six.iteritems(token_is_max_context) - ])) - print("input_ids: %s" % " ".join([str(x) for x in input_ids])) - print("input_mask: %s" % " ".join([str(x) for x in input_mask])) - print("segment_ids: %s" % - " ".join([str(x) for x in segment_ids])) - if is_training and example.is_impossible: - print("impossible example") - if is_training and not example.is_impossible: - answer_text = " ".join(tokens[start_position:(end_position + - 1)]) - print("start_position: %d" % (start_position)) - print("end_position: %d" % (end_position)) - print("answer: %s" % - (tokenization.printable_text(answer_text))) - """ - - feature = InputFeatures( - unique_id=unique_id, - example_index=example_index, - doc_span_index=doc_span_index, - tokens=tokens, - token_to_orig_map=token_to_orig_map, - token_is_max_context=token_is_max_context, - input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids, - start_position=start_position, - end_position=end_position, - is_impossible=example.is_impossible) - - unique_id += 1 - - yield feature - - -def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, - orig_answer_text): - """Returns tokenized answer spans that better match the annotated answer.""" - - # The SQuAD annotations are character based. We first project them to - # whitespace-tokenized words. But then after WordPiece tokenization, we can - # often find a "better match". For example: - # - # Question: What year was John Smith born? - # Context: The leader was John Smith (1895-1943). - # Answer: 1895 - # - # The original whitespace-tokenized answer will be "(1895-1943).". However - # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match - # the exact answer, 1895. - # - # However, this is not always possible. Consider the following: - # - # Question: What country is the top exporter of electornics? - # Context: The Japanese electronics industry is the lagest in the world. - # Answer: Japan - # - # In this case, the annotator chose "Japan" as a character sub-span of - # the word "Japanese". Since our WordPiece tokenizer does not split - # "Japanese", we just use "Japanese" as the annotation. This is fairly rare - # in SQuAD, but does happen. 
- tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) - - for new_start in range(input_start, input_end + 1): - for new_end in range(input_end, new_start - 1, -1): - text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) - if text_span == tok_answer_text: - return (new_start, new_end) - - return (input_start, input_end) - - -def _check_is_max_context(doc_spans, cur_span_index, position): - """Check if this is the 'max context' doc span for the token.""" - - # Because of the sliding window approach taken to scoring documents, a single - # token can appear in multiple documents. E.g. - # Doc: the man went to the store and bought a gallon of milk - # Span A: the man went to the - # Span B: to the store and bought - # Span C: and bought a gallon of - # ... - # - # Now the word 'bought' will have two scores from spans B and C. We only - # want to consider the score with "maximum context", which we define as - # the *minimum* of its left and right context (the *sum* of left and - # right context will always be the same, of course). - # - # In the example the maximum context for 'bought' would be span C since - # it has 1 left context and 3 right context, while span B has 4 left context - # and 0 right context. - best_score = None - best_span_index = None - for (span_index, doc_span) in enumerate(doc_spans): - end = doc_span.start + doc_span.length - 1 - if position < doc_span.start: - continue - if position > end: - continue - num_left_context = position - doc_span.start - num_right_context = end - position - score = min(num_left_context, - num_right_context) + 0.01 * doc_span.length - if best_score is None or score > best_score: - best_score = score - best_span_index = span_index - - return cur_span_index == best_span_index - - -class DataProcessor(object): - def __init__(self, vocab_path, do_lower_case, max_seq_length, in_tokens, - doc_stride, max_query_length): - self._tokenizer = tokenization.FullTokenizer( - vocab_file=vocab_path, do_lower_case=do_lower_case) - self._max_seq_length = max_seq_length - self._doc_stride = doc_stride - self._max_query_length = max_query_length - self._in_tokens = in_tokens - - self.vocab = self._tokenizer.vocab - self.vocab_size = len(self.vocab) - self.pad_id = self.vocab["[PAD]"] - self.cls_id = self.vocab["[CLS]"] - self.sep_id = self.vocab["[SEP]"] - self.mask_id = self.vocab["[MASK]"] - - self.current_train_example = -1 - self.num_train_examples = -1 - self.current_train_epoch = -1 - - self.train_examples = None - self.predict_examples = None - self.num_examples = {'train': -1, 'predict': -1} - - def get_train_progress(self): - """Gets progress for training phase.""" - return self.current_train_example, self.current_train_epoch - - def get_examples(self, - data_path, - is_training, - version_2_with_negative=False): - examples = read_squad_examples( - input_file=data_path, - is_training=is_training, - version_2_with_negative=version_2_with_negative) - return examples - - def get_num_examples(self, phase): - if phase not in ['train', 'predict']: - raise ValueError( - "Unknown phase, which should be in ['train', 'predict'].") - return self.num_examples[phase] - - def get_features(self, examples, is_training): - features = convert_examples_to_features( - examples=examples, - tokenizer=self._tokenizer, - max_seq_length=self._max_seq_length, - doc_stride=self._doc_stride, - max_query_length=self._max_query_length, - is_training=is_training) - return features - - def data_generator(self, - data_path, - batch_size, - phase='train', - shuffle=False, 
- dev_count=1, - version_2_with_negative=False, - epoch=1): - if phase == 'train': - self.train_examples = self.get_examples( - data_path, - is_training=True, - version_2_with_negative=version_2_with_negative) - examples = self.train_examples - self.num_examples['train'] = len(self.train_examples) - elif phase == 'predict': - self.predict_examples = self.get_examples( - data_path, - is_training=False, - version_2_with_negative=version_2_with_negative) - examples = self.predict_examples - self.num_examples['predict'] = len(self.predict_examples) - else: - raise ValueError( - "Unknown phase, which should be in ['train', 'predict'].") - - def batch_reader(features, batch_size, in_tokens): - batch, total_token_num, max_len = [], 0, 0 - for (index, feature) in enumerate(features): - if phase == 'train': - self.current_train_example = index + 1 - seq_len = len(feature.input_ids) - labels = [feature.unique_id - ] if feature.start_position is None else [ - feature.start_position, feature.end_position - ] - example = [ - feature.input_ids, feature.segment_ids, range(seq_len) - ] + labels - max_len = max(max_len, seq_len) - - #max_len = max(max_len, len(token_ids)) - if in_tokens: - to_append = (len(batch) + 1) * max_len <= batch_size - else: - to_append = len(batch) < batch_size - - if to_append: - batch.append(example) - total_token_num += seq_len - else: - yield batch, total_token_num - batch, total_token_num, max_len = [example - ], seq_len, seq_len - if len(batch) > 0: - yield batch, total_token_num - - def wrapper(): - for epoch_index in range(epoch): - if shuffle: - random.shuffle(examples) - if phase == 'train': - self.current_train_epoch = epoch_index - features = self.get_features(examples, is_training=True) - else: - features = self.get_features(examples, is_training=False) - - all_dev_batches = [] - for batch_data, total_token_num in batch_reader( - features, batch_size, self._in_tokens): - batch_data = prepare_batch_data( - batch_data, - total_token_num, - voc_size=-1, - pad_id=self.pad_id, - cls_id=self.cls_id, - sep_id=self.sep_id, - mask_id=-1, - return_input_mask=True, - return_max_len=False, - return_num_token=False) - if len(all_dev_batches) < dev_count: - all_dev_batches.append(batch_data) - - if len(all_dev_batches) == dev_count: - for batch in all_dev_batches: - yield batch - all_dev_batches = [] - - return wrapper - - -def write_predictions(all_examples, all_features, all_results, n_best_size, - max_answer_length, do_lower_case, output_prediction_file, - output_nbest_file, output_null_log_odds_file, - version_2_with_negative, null_score_diff_threshold, - verbose): - """Write final predictions to the json file and log-odds of null if needed.""" - print("Writing predictions to: %s" % (output_prediction_file)) - print("Writing nbest to: %s" % (output_nbest_file)) - - example_index_to_features = collections.defaultdict(list) - for feature in all_features: - example_index_to_features[feature.example_index].append(feature) - - unique_id_to_result = {} - for result in all_results: - unique_id_to_result[result.unique_id] = result - - _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name - "PrelimPrediction", [ - "feature_index", "start_index", "end_index", "start_logit", - "end_logit" - ]) - - all_predictions = collections.OrderedDict() - all_nbest_json = collections.OrderedDict() - scores_diff_json = collections.OrderedDict() - - for (example_index, example) in enumerate(all_examples): - features = example_index_to_features[example_index] - - prelim_predictions = 
[] - # keep track of the minimum score of null start+end of position 0 - score_null = 1000000 # large and positive - min_null_feature_index = 0 # the paragraph slice with min mull score - null_start_logit = 0 # the start logit at the slice with min null score - null_end_logit = 0 # the end logit at the slice with min null score - for (feature_index, feature) in enumerate(features): - result = unique_id_to_result[feature.unique_id] - start_indexes = _get_best_indexes(result.start_logits, n_best_size) - end_indexes = _get_best_indexes(result.end_logits, n_best_size) - # if we could have irrelevant answers, get the min score of irrelevant - if version_2_with_negative: - feature_null_score = result.start_logits[ - 0] + result.end_logits[0] - if feature_null_score < score_null: - score_null = feature_null_score - min_null_feature_index = feature_index - null_start_logit = result.start_logits[0] - null_end_logit = result.end_logits[0] - for start_index in start_indexes: - for end_index in end_indexes: - # We could hypothetically create invalid predictions, e.g., predict - # that the start of the span is in the question. We throw out all - # invalid predictions. - if start_index >= len(feature.tokens): - continue - if end_index >= len(feature.tokens): - continue - if start_index not in feature.token_to_orig_map: - continue - if end_index not in feature.token_to_orig_map: - continue - if not feature.token_is_max_context.get(start_index, - False): - continue - if end_index < start_index: - continue - length = end_index - start_index + 1 - if length > max_answer_length: - continue - prelim_predictions.append( - _PrelimPrediction( - feature_index=feature_index, - start_index=start_index, - end_index=end_index, - start_logit=result.start_logits[start_index], - end_logit=result.end_logits[end_index])) - - if version_2_with_negative: - prelim_predictions.append( - _PrelimPrediction( - feature_index=min_null_feature_index, - start_index=0, - end_index=0, - start_logit=null_start_logit, - end_logit=null_end_logit)) - prelim_predictions = sorted( - prelim_predictions, - key=lambda x: (x.start_logit + x.end_logit), - reverse=True) - - _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name - "NbestPrediction", ["text", "start_logit", "end_logit"]) - - seen_predictions = {} - nbest = [] - for pred in prelim_predictions: - if len(nbest) >= n_best_size: - break - feature = features[pred.feature_index] - if pred.start_index > 0: # this is a non-null prediction - tok_tokens = feature.tokens[pred.start_index:(pred.end_index + - 1)] - orig_doc_start = feature.token_to_orig_map[pred.start_index] - orig_doc_end = feature.token_to_orig_map[pred.end_index] - orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + - 1)] - tok_text = " ".join(tok_tokens) - - # De-tokenize WordPieces that have been split off. 
- tok_text = tok_text.replace(" ##", "") - tok_text = tok_text.replace("##", "") - - # Clean whitespace - tok_text = tok_text.strip() - tok_text = " ".join(tok_text.split()) - orig_text = " ".join(orig_tokens) - - final_text = get_final_text(tok_text, orig_text, do_lower_case, - verbose) - if final_text in seen_predictions: - continue - - seen_predictions[final_text] = True - else: - final_text = "" - seen_predictions[final_text] = True - - nbest.append( - _NbestPrediction( - text=final_text, - start_logit=pred.start_logit, - end_logit=pred.end_logit)) - - # if we didn't inlude the empty option in the n-best, inlcude it - if version_2_with_negative: - if "" not in seen_predictions: - nbest.append( - _NbestPrediction( - text="", - start_logit=null_start_logit, - end_logit=null_end_logit)) - # In very rare edge cases we could have no valid predictions. So we - # just create a nonce prediction in this case to avoid failure. - if not nbest: - nbest.append( - _NbestPrediction( - text="empty", start_logit=0.0, end_logit=0.0)) - - assert len(nbest) >= 1 - - total_scores = [] - best_non_null_entry = None - for entry in nbest: - total_scores.append(entry.start_logit + entry.end_logit) - if not best_non_null_entry: - if entry.text: - best_non_null_entry = entry - # debug - if best_non_null_entry is None: - print("Emmm..., sth wrong") - - probs = _compute_softmax(total_scores) - - nbest_json = [] - for (i, entry) in enumerate(nbest): - output = collections.OrderedDict() - output["text"] = entry.text - output["probability"] = probs[i] - output["start_logit"] = entry.start_logit - output["end_logit"] = entry.end_logit - nbest_json.append(output) - - assert len(nbest_json) >= 1 - - if not version_2_with_negative: - all_predictions[example.qas_id] = nbest_json[0]["text"] - else: - # predict "" iff the null score - the score of best non-null > threshold - score_diff = score_null - best_non_null_entry.start_logit - ( - best_non_null_entry.end_logit) - scores_diff_json[example.qas_id] = score_diff - if score_diff > null_score_diff_threshold: - all_predictions[example.qas_id] = "" - else: - all_predictions[example.qas_id] = best_non_null_entry.text - - all_nbest_json[example.qas_id] = nbest_json - - with open(output_prediction_file, "w") as writer: - writer.write(json.dumps(all_predictions, indent=4) + "\n") - - with open(output_nbest_file, "w") as writer: - writer.write(json.dumps(all_nbest_json, indent=4) + "\n") - - if version_2_with_negative: - with open(output_null_log_odds_file, "w") as writer: - writer.write(json.dumps(scores_diff_json, indent=4) + "\n") - - -def get_final_text(pred_text, orig_text, do_lower_case, verbose): - """Project the tokenized prediction back to the original text.""" - - # When we created the data, we kept track of the alignment between original - # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So - # now `orig_text` contains the span of our original text corresponding to the - # span that we predicted. - # - # However, `orig_text` may contain extra characters that we don't want in - # our prediction. - # - # For example, let's say: - # pred_text = steve smith - # orig_text = Steve Smith's - # - # We don't want to return `orig_text` because it contains the extra "'s". - # - # We don't want to return `pred_text` because it's already been normalized - # (the SQuAD eval script also does punctuation stripping/lower casing but - # our tokenizer does additional normalization like stripping accent - # characters). 
- # - # What we really want to return is "Steve Smith". - # - # Therefore, we have to apply a semi-complicated alignment heruistic between - # `pred_text` and `orig_text` to get a character-to-charcter alignment. This - # can fail in certain cases in which case we just return `orig_text`. - - def _strip_spaces(text): - ns_chars = [] - ns_to_s_map = collections.OrderedDict() - for (i, c) in enumerate(text): - if c == " ": - continue - ns_to_s_map[len(ns_chars)] = i - ns_chars.append(c) - ns_text = "".join(ns_chars) - return (ns_text, ns_to_s_map) - - # We first tokenize `orig_text`, strip whitespace from the result - # and `pred_text`, and check if they are the same length. If they are - # NOT the same length, the heuristic has failed. If they are the same - # length, we assume the characters are one-to-one aligned. - tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) - - tok_text = " ".join(tokenizer.tokenize(orig_text)) - - start_position = tok_text.find(pred_text) - if start_position == -1: - if verbose: - print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) - return orig_text - end_position = start_position + len(pred_text) - 1 - - (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) - (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) - - if len(orig_ns_text) != len(tok_ns_text): - if verbose: - print("Length not equal after stripping spaces: '%s' vs '%s'", - orig_ns_text, tok_ns_text) - return orig_text - - # We then project the characters in `pred_text` back to `orig_text` using - # the character-to-character alignment. - tok_s_to_ns_map = {} - for (i, tok_index) in six.iteritems(tok_ns_to_s_map): - tok_s_to_ns_map[tok_index] = i - - orig_start_position = None - if start_position in tok_s_to_ns_map: - ns_start_position = tok_s_to_ns_map[start_position] - if ns_start_position in orig_ns_to_s_map: - orig_start_position = orig_ns_to_s_map[ns_start_position] - - if orig_start_position is None: - if verbose: - print("Couldn't map start position") - return orig_text - - orig_end_position = None - if end_position in tok_s_to_ns_map: - ns_end_position = tok_s_to_ns_map[end_position] - if ns_end_position in orig_ns_to_s_map: - orig_end_position = orig_ns_to_s_map[ns_end_position] - - if orig_end_position is None: - if verbose: - print("Couldn't map end position") - return orig_text - - output_text = orig_text[orig_start_position:(orig_end_position + 1)] - return output_text - - -def _get_best_indexes(logits, n_best_size): - """Get the n-best logits from a list.""" - index_and_score = sorted( - enumerate(logits), key=lambda x: x[1], reverse=True) - - best_indexes = [] - for i in range(len(index_and_score)): - if i >= n_best_size: - break - best_indexes.append(index_and_score[i][0]) - return best_indexes - - -def _compute_softmax(scores): - """Compute softmax probability over raw logits.""" - if not scores: - return [] - - max_score = None - for score in scores: - if max_score is None or score > max_score: - max_score = score - - exp_scores = [] - total_sum = 0.0 - for score in scores: - x = math.exp(score - max_score) - exp_scores.append(x) - total_sum += x - - probs = [] - for score in exp_scores: - probs.append(score / total_sum) - return probs - - -if __name__ == '__main__': - train_file = 'squad/train-v1.1.json' - vocab_file = 'uncased_L-12_H-768_A-12/vocab.txt' - do_lower_case = True - tokenizer = tokenization.FullTokenizer( - vocab_file=vocab_file, do_lower_case=do_lower_case) - train_examples = read_squad_examples( - input_file=train_file, 
is_training=True) - print("begin converting") - for (index, feature) in enumerate( - convert_examples_to_features( - examples=train_examples, - tokenizer=tokenizer, - max_seq_length=384, - doc_stride=128, - max_query_length=64, - is_training=True, - #output_fn=train_writer.process_feature - )): - if index < 10: - print(index, feature.input_ids, feature.input_mask, - feature.segment_ids) - #for (index, example) in enumerate(train_examples): - # if index < 5: - # print(example)
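
The `transformer_encoder.py` changes earlier in this patch add static per-operation cost tables (`FLOPs`, `ModelSize`) and charge each candidate op's cost weighted by its sampled architecture weight, so the totals that feed `e_loss` remain a function of the alphas; the new `k` parameter plays the analogous role across layers, since its gumbel-softmax sample decides after how many cells the encoder stops stacking and the same `k[i]` scales the CE and efficiency terms in `cls.py`. The snippet below is only a toy illustration of that bookkeeping: the `gumbel_softmax` here is a simplified stand-in for the one used by the encoder, and the expectation-style accumulation over all ops approximates `MixedOp.forward`, which charges only the ops whose sampled weight is non-zero.

```python
import numpy as np

PRIMITIVES = [
    'std_conv_3', 'std_conv_5', 'std_conv_7', 'dil_conv_3', 'dil_conv_5',
    'dil_conv_7', 'avg_pool_3', 'max_pool_3', 'none', 'skip_connect'
]
INPUT_SIZE = 128 * 768  # seq_len * hidden size, as in the diff
KERNEL = {'std_conv_3': 3, 'std_conv_5': 5, 'std_conv_7': 7,
          'dil_conv_3': 3, 'dil_conv_5': 5, 'dil_conv_7': 7,
          'avg_pool_3': 3, 'max_pool_3': 3, 'none': 0, 'skip_connect': 0}
FLOPS = np.array([INPUT_SIZE * KERNEL[p] for p in PRIMITIVES], dtype=float)


def gumbel_softmax(logits, tau=1.0, rng=np.random):
    """Toy relaxation: a soft (near one-hot) sample over the candidate ops."""
    g = -np.log(-np.log(rng.uniform(1e-9, 1.0, size=logits.shape)))
    y = (logits + g) / tau
    e = np.exp(y - y.max())
    return e / e.sum()


# one edge of one cell: sample op weights and charge the weighted cost,
# the same bookkeeping MixedOp.forward does via flops.append(...)
alphas = np.zeros(len(PRIMITIVES))  # untrained architecture logits
weights = gumbel_softmax(alphas)
edge_flops = float(np.sum(weights * FLOPS))
print("expected FLOPs charged for this edge: %.0f" % edge_flops)
```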