# Tokenizers (whitespace / word-piece) and helpers for building model input sequences.
import itertools
import re
import sys

import numpy as np
import six
from propeller import log
from propeller.paddle.data import Dataset

if six.PY2:
    import operator

    def accumulate(iterable, func=operator.add, initial=None):
        """Yield running totals; a Python 2 backport of ``itertools.accumulate``.

        Examples:
            accumulate([1,2,3,4,5]) --> 1 3 6 10 15
            accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
            accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120

        Args:
            iterable: values to fold.
            func: binary function combining the running total with the next
                element; defaults to addition.
            initial: optional seed emitted first; when ``None`` the first
                element of ``iterable`` is used as the seed.
        """
        it = iter(iterable)
        total = initial
        if initial is None:
            try:
                total = next(it)
            except StopIteration:
                # Empty input and no seed: there is nothing to yield.
                return
        yield total
        for element in it:
            total = func(total, element)
            yield total
else:
    from itertools import accumulate

def whitespace_tokenize(text):
    """Strip surrounding whitespace from a piece of text and split it on whitespace runs.

    Returns an empty list when nothing but whitespace is left.
    """
    stripped = text.strip()
    return stripped.split() if stripped else []

# Words with more characters than this are mapped straight to the unknown
# token instead of being split.
# NOTE(review): 100 is the customary word-piece limit; confirm it matches the
# value this project expects.
max_input_chars_per_word = 100


def wordpiece(token, vocab, unk_token, sentencepiece_style_vocab=False):
    """Split a single word into word pieces using greedy longest-match-first.

    Call with a single word (no whitespace handling is done here).

    Args:
        token: the word to split.
        vocab: container supporting ``in`` lookups of piece strings.
        unk_token: token returned when the word cannot be fully covered by
            vocabulary pieces, or is longer than ``max_input_chars_per_word``.
        sentencepiece_style_vocab: when true, the first piece is looked up
            with a leading U+2581 marker and continuation pieces are looked up
            bare; otherwise the first piece is looked up bare and continuation
            pieces carry a ``##`` prefix.

    Returns:
        A ``(sub_tokens, sub_pos)`` tuple. ``sub_tokens`` is the list of
        matched vocabulary entries and ``sub_pos`` the list of
        ``(start, end)`` character offsets of each piece within ``token``.
        For an unsplittable word this is ``([unk_token], [(0, len(token))])``.
    """
    chars = list(token)
    if len(chars) > max_input_chars_per_word:
        return [unk_token], [(0, len(chars))]

    sub_tokens = []
    sub_pos = []
    start = 0
    while start < len(chars):
        end = len(chars)
        cur_substr = None
        # Shrink the candidate from the right until it is found in the vocab.
        while start < end:
            substr = "".join(chars[start:end])
            if start == 0 and sentencepiece_style_vocab:
                substr = u'\u2581' + substr
            if start > 0 and not sentencepiece_style_vocab:
                substr = "##" + substr
            if substr in vocab:
                cur_substr = substr
                break  # longest match wins; stop shrinking
            end -= 1
        if cur_substr is None:
            # No prefix of the remainder is in the vocab: the whole word is unknown.
            return [unk_token], [(0, len(chars))]
        sub_tokens.append(cur_substr)
        sub_pos.append((start, end))
        start = end
    return sub_tokens, sub_pos

class SpaceTokenizer(object):
    """Tokenizer that splits on single spaces and maps unknown words to ``[UNK]``."""

    def __init__(self, vocab, lower=True):
        """
        Args:
            vocab: iterable of known tokens; stored as a set for lookups.
            lower: when true, input is lowercased before vocabulary lookup.
        """
        self.vocab = set(vocab)
        self.lower = lower

    def __call__(self, sen):
        """Tokenize one sentence.

        Args:
            sen: UTF-8 encoded bytes, or already-decoded text.

        Returns:
            List of tokens; every space-separated word that is not in the
            vocabulary is replaced by ``'[UNK]'``. Empty input gives ``[]``.
        """
        if len(sen) == 0:
            return []  # empty line
        # Only byte strings need decoding; decoded text is used as is.
        if isinstance(sen, bytes):
            sen = sen.decode('utf8')
        if self.lower:
            sen = sen.lower()
        res = []
        for s in sen.split(' '):
            # Consecutive spaces make split(' ') produce empty pieces; skip them.
            if s in ('', ' '):
                continue
            res.append(s if s in self.vocab else '[UNK]')
        return res

class CharTokenizer(object):
    """Tokenizer that splits text into alphanumeric runs and single characters,
    then breaks each match into word pieces."""

    def __init__(self, vocab, lower=True):
        """
        char tokenizer (wordpiece english)
        normed txt (space separated or not) => list of word-piece

        Args:
            vocab: iterable of known pieces; stored as a set for lookups.
            lower: when true, input is lowercased before matching.
        """
        self.vocab = set(vocab)
        # Earlier, narrower pattern kept for reference:
        #self.pat = re.compile(r'([,.!?\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]|[\u4e00-\u9fa5]|[a-zA-Z0-9]+)')
        # A run of ASCII letters/digits is one unit; any other non-space
        # character is a unit of its own.
        self.pat = re.compile(r'([a-zA-Z0-9]+|\S)')
        self.lower = lower

    def __call__(self, sen):
        """Tokenize one sentence.

        Args:
            sen: UTF-8 encoded bytes, or already-decoded text.

        Returns:
            Flat list of word pieces produced by ``wordpiece`` for every
            pattern match, with ``'[UNK]'`` as the unknown token. Empty input
            gives ``[]``.
        """
        if len(sen) == 0:
            return []  # empty line
        # Only byte strings need decoding; decoded text is used as is.
        if isinstance(sen, bytes):
            sen = sen.decode('utf8')
        if self.lower:
            sen = sen.lower()
        res = []
        for match in self.pat.finditer(sen):
            words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]')
            res.extend(words)
        return res

def build_2_pair(seg_a, seg_b, max_seqlen, cls_id, sep_id):
    """Build the id sequence and token-type sequence for a two-segment input.

    The id sequence is ``[cls_id] + seg_a + [sep_id] + seg_b + [sep_id]``.
    Token types are 0 for the leading ``cls_id``, ``seg_a`` and the first
    separator, and 1 for ``seg_b`` and the final separator.

    Args:
        seg_a: 1-D sequence of ids for the first segment.
        seg_b: 1-D sequence of ids for the second segment.
        max_seqlen: maximum length of the returned sequences.
        cls_id: id placed at the start.
        sep_id: id placed after each segment.

    Returns:
        ``(sen_emb, token_type_emb)``: two 1-D numpy arrays of equal length,
        each truncated to at most ``max_seqlen`` elements.
    """
    token_type_a = np.zeros_like(seg_a, dtype=np.int64)
    token_type_b = np.ones_like(seg_b, dtype=np.int64)
    sen_emb = np.concatenate([[cls_id], seg_a, [sep_id], seg_b, [sep_id]], 0)
    token_type_emb = np.concatenate([[0], token_type_a, [0], token_type_b, [1]], 0)

    # Truncation always keeps the head of the sequence; choosing a random
    # start offset is currently disabled.
    sen_emb = sen_emb[:max_seqlen]
    token_type_emb = token_type_emb[:max_seqlen]

    return sen_emb, token_type_emb

def build_1_pair(seg_a, max_seqlen, cls_id, sep_id):
    """Build the id sequence and token-type sequence for a single-segment input.

    The id sequence is ``[cls_id] + seg_a + [sep_id]``; every token type is 0.

    Args:
        seg_a: 1-D sequence of ids for the segment.
        max_seqlen: maximum length of the returned sequences.
        cls_id: id placed at the start.
        sep_id: id placed at the end.

    Returns:
        ``(sen_emb, token_type_emb)``: two 1-D numpy arrays of equal length,
        each truncated to at most ``max_seqlen`` elements.
    """
    token_type_a = np.zeros_like(seg_a, dtype=np.int64)

    sen_emb = np.concatenate([[cls_id], seg_a, [sep_id]], 0)
    token_type_emb = np.concatenate([[0], token_type_a, [0]], 0)

    # Truncation always keeps the head of the sequence; choosing a random
    # start offset is currently disabled.
    sen_emb = sen_emb[:max_seqlen]
    token_type_emb = token_type_emb[:max_seqlen]
    return sen_emb, token_type_emb

def expand_dims(*args):
    """Return a list holding each argument with a new trailing axis of length 1."""
    return [np.expand_dims(arr, -1) for arr in args]

def interleave(ds1, ds2):
    """Alternate elements of ``ds1`` and ``ds2`` in a new generator-backed Dataset.

    Elements are taken pairwise, first from ``ds1`` then from ``ds2``. Once
    one input is exhausted the remaining elements of the other are emitted on
    their own. ``None`` is the fill value for the shorter input, so elements
    that are themselves ``None`` are dropped as well.
    """
    def gen():
        paired = six.moves.zip_longest(iter(ds1), iter(ds2))
        for pair in paired:
            for item in pair:
                if item is not None:
                    yield item
    return Dataset.from_generator_func(gen)