"""this file is a copy of https://github.com/zihangdai/xlnet""" import re import numpy as np from data_utils import SEP_ID, CLS_ID SEG_ID_A = 0 SEG_ID_B = 1 SEG_ID_CLS = 2 SEG_ID_SEP = 3 SEG_ID_PAD = 4 class PaddingInputExample(object): """Fake example so the num input examples is a multiple of the batch size. When running eval/predict on the TPU, we need to pad the number of examples to be a multiple of the batch size, because the TPU requires a fixed batch size. The alternative is to drop the last batch, which is bad because it means the entire output data won't be generated. We use this class instead of `None` because treating `None` as padding battches could cause silent errors. """ class InputFeatures(object): """A single set of features of data.""" def __init__(self, input_ids, input_mask, segment_ids, label_id, is_real_example=True): self.input_ids = input_ids self.input_mask = input_mask self.segment_ids = segment_ids self.label_id = label_id self.is_real_example = is_real_example def _truncate_seq_pair(tokens_a, tokens_b, max_length): """Truncates a sequence pair in place to the maximum length.""" # This is a simple heuristic which will always truncate the longer sequence # one token at a time. This makes more sense than truncating an equal percent # of tokens from each, since if one sequence is very short then each token # that's truncated likely contains more information than a longer sequence. while True: total_length = len(tokens_a) + len(tokens_b) if total_length <= max_length: break if len(tokens_a) > len(tokens_b): tokens_a.pop() else: tokens_b.pop() def convert_single_example(ex_index, example, label_list, max_seq_length, tokenize_fn): """Converts a single `InputExample` into a single `InputFeatures`.""" if isinstance(example, PaddingInputExample): return InputFeatures( input_ids=[0] * max_seq_length, input_mask=[1] * max_seq_length, segment_ids=[0] * max_seq_length, label_id=0, is_real_example=False) if label_list is not None: label_map = {} for (i, label) in enumerate(label_list): label_map[label] = i tokens_a = tokenize_fn(example.text_a) tokens_b = None if example.text_b: tokens_b = tokenize_fn(example.text_b) if tokens_b: # Modifies `tokens_a` and `tokens_b` in place so that the total # length is less than the specified length. # Account for two [SEP] & one [CLS] with "- 3" _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) else: # Account for one [SEP] & one [CLS] with "- 2" if len(tokens_a) > max_seq_length - 2: tokens_a = tokens_a[:max_seq_length - 2] tokens = [] segment_ids = [] for token in tokens_a: tokens.append(token) segment_ids.append(SEG_ID_A) tokens.append(SEP_ID) segment_ids.append(SEG_ID_A) if tokens_b: for token in tokens_b: tokens.append(token) segment_ids.append(SEG_ID_B) tokens.append(SEP_ID) segment_ids.append(SEG_ID_B) tokens.append(CLS_ID) segment_ids.append(SEG_ID_CLS) input_ids = tokens # The mask has 0 for real tokens and 1 for padding tokens. Only real # tokens are attended to. input_mask = [0] * len(input_ids) # Zero-pad up to the sequence length. if len(input_ids) < max_seq_length: delta_len = max_seq_length - len(input_ids) input_ids = [0] * delta_len + input_ids input_mask = [1] * delta_len + input_mask segment_ids = [SEG_ID_PAD] * delta_len + segment_ids assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length assert len(segment_ids) == max_seq_length if label_list is not None: label_id = label_map[example.label] else: label_id = example.label if ex_index < 1: print("*** Example ***") print("guid: %s" % (example.guid)) print("input_ids: %s" % " ".join([str(x) for x in input_ids])) print("input_mask: %s" % " ".join([str(x) for x in input_mask])) print("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) print("label: {} (id = {})".format(example.label, label_id)) feature = InputFeatures( input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label_id) return feature