bert_reader.py
# Helper modules shipped alongside this file: batch padding and
# WordPiece tokenization utilities.
from batching import pad_batch_data
import tokenization

class BertReader(object):
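    """Converts a raw sentence into padded BERT feed arrays.

    Wraps a WordPiece FullTokenizer and emits the four inputs a BERT
    model expects -- token ids, position ids, segment ids and an input
    mask -- each padded to max_seq_len.
    """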
    def __init__(self, vocab_file="", max_seq_len=128):
        self.vocab_file = vocab_file
        self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file)
        self.max_seq_len = max_seq_len
        self.vocab = self.tokenizer.vocab
        # Cache the ids of the special tokens used to build model inputs.
        self.pad_id = self.vocab["[PAD]"]
        self.cls_id = self.vocab["[CLS]"]
        self.sep_id = self.vocab["[SEP]"]
        self.mask_id = self.vocab["[MASK]"]

    def pad_batch(self, token_ids, text_type_ids, position_ids):
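        """Pads one example (treated as a batch of size one).

        Returns (padded_token_ids, padded_position_ids,
        padded_text_type_ids, input_mask) as numpy arrays produced by
        pad_batch_data.
        """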
        # Wrap the single example into a batch of one so the generic
        # padding helper can be reused.
        batch_token_ids = [token_ids]
        batch_text_type_ids = [text_type_ids]
        batch_position_ids = [position_ids]

        # Pad everything to max_seq_len; the input mask marks real
        # tokens with 1 and padding with 0.
        padded_token_ids, input_mask = pad_batch_data(
            batch_token_ids,
            max_seq_len=self.max_seq_len,
            pad_idx=self.pad_id,
            return_input_mask=True)
        padded_text_type_ids = pad_batch_data(
            batch_text_type_ids,
            max_seq_len=self.max_seq_len,
            pad_idx=self.pad_id)
        # Note: pad_id is reused for position ids; in standard BERT
        # vocabs "[PAD]" maps to 0, and padded positions are ignored
        # via input_mask anyway.
        padded_position_ids = pad_batch_data(
            batch_position_ids,
            max_seq_len=self.max_seq_len,
            pad_idx=self.pad_id)
        return padded_token_ids, padded_position_ids, padded_text_type_ids, input_mask

    def process(self, sent):
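        """Tokenizes `sent` and returns a feed dict of flat id lists."""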
        text_a = tokenization.convert_to_unicode(sent)
        tokens_a = self.tokenizer.tokenize(text_a)
        # Truncate, reserving two slots for [CLS] and [SEP].
        if len(tokens_a) > self.max_seq_len - 2:
            tokens_a = tokens_a[0:(self.max_seq_len - 2)]
        # Build the final sequence: [CLS] + tokens + [SEP], all in
        # segment 0 (single-sentence input).
        tokens = []
        text_type_ids = []
        tokens.append("[CLS]")
        text_type_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            text_type_ids.append(0)
        tokens.append("[SEP]")
        text_type_ids.append(0)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        position_ids = list(range(len(token_ids)))
        p_token_ids, p_pos_ids, p_text_type_ids, input_mask = self.pad_batch(
            token_ids, text_type_ids, position_ids)
        # Flatten each array to a plain Python list so the result can be
        # fed directly to a model or serving endpoint.
        feed_result = {"input_ids": p_token_ids.reshape(-1).tolist(),
                       "position_ids": p_pos_ids.reshape(-1).tolist(),
                       "segment_ids": p_text_type_ids.reshape(-1).tolist(),
                       "input_mask": input_mask.reshape(-1).tolist()}
        return feed_result
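
# Usage sketch: a minimal, hypothetical demo. It assumes a standard BERT
# WordPiece vocab file at "vocab.txt"; the path is an assumption, not
# part of this module.
if __name__ == "__main__":
    reader = BertReader(vocab_file="vocab.txt", max_seq_len=128)
    feed = reader.process("hello world")
    # Every entry is a flat list of length max_seq_len (batch of one).
    for name in ("input_ids", "position_ids", "segment_ids", "input_mask"):
        print(name, len(feed[name]))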