# data.py
import sys
import numpy as np
import re
from propeller import log
import itertools
from propeller.paddle.data import Dataset
import pickle

import six

if six.PY2:
    import operator
    def accumulate(iterable, func=operator.add, initial=None):
        'Return running totals'
        # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
        # accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
        # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
        it = iter(iterable)
        total = initial
        if initial is None:
            try:
                total = next(it)
            except StopIteration:
                return
        yield total
        for element in it:
            total = func(total, element)
            yield total
else:
    from itertools import accumulate


max_input_chars_per_word = 100
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a peice of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


def wordpiece(token, vocab, unk_token, sentencepiece_style_vocab=False):
    """call with single word"""
    chars = list(token)
    if len(chars) > max_input_chars_per_word:
        return [unk_token], [(0, len(chars))]

    is_bad = False
    start = 0
    sub_tokens = []
    sub_pos = []
    while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
            substr = "".join(chars[start:end])
            if start == 0 and sentencepiece_style_vocab:
                substr = u'\u2581' + substr
            if start > 0 and not sentencepiece_style_vocab:
                substr = "##" + substr
            if substr in vocab:
                cur_substr = substr
                break
            end -= 1
        if cur_substr is None:
            is_bad = True
            break
        sub_tokens.append(cur_substr)
        sub_pos.append((start, end))
        start = end
    if is_bad:
        return [unk_token], [(0, len(chars))]
    else:
        return sub_tokens, sub_pos
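
# Illustrative sketch, assuming a hypothetical vocab, showing the greedy
# longest-match-first behaviour of wordpiece():
#   wordpiece('unaffable', vocab={'un', '##aff', '##able'}, unk_token='[UNK]')
#   -> (['un', '##aff', '##able'], [(0, 2), (2, 5), (5, 9)])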


class SpaceTokenizer(object):
    def __init__(self, vocab, lower=True):
        """
        char tokenizer (wordpiece english)
        normed txt(space seperated or not) => list of word-piece
        """
        self.vocab = set(vocab)
        self.lower = lower

    def __call__(self, sen):
        if len(sen) == 0:
            return [] #empty line
        sen = sen.decode('utf8')
        if self.lower:
            sen = sen.lower()
        res = []
        for s in sen.split(' '):
            if not s:  # consecutive spaces yield empty strings, skip them
                continue
            if s in self.vocab:
                res.append(s)
            else:
                res.append('[UNK]')
        return res
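
# Illustrative sketch, assuming a hypothetical vocab:
#   SpaceTokenizer(vocab={'hello', 'world'})(b'Hello foo')
#   -> ['hello', '[UNK]']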


class CharTokenizer(object):
    def __init__(self, vocab, lower=True, sentencepiece_style_vocab=False):
        """
        char tokenizer (wordpiece english)
        normed txt (space separated or not) => list of word-pieces
        """
        self.vocab = set(vocab)
        #self.pat = re.compile(r'([,.!?\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]|[\u4e00-\u9fa5]|[a-zA-Z0-9]+)')
        self.pat = re.compile(r'([a-zA-Z0-9]+|\S)')
        self.lower = lower
        self.sentencepiece_style_vocab = sentencepiece_style_vocab

    def __call__(self, sen):
        if len(sen) == 0:
            return [] #empty line
        sen = sen.decode('utf8')
        if self.lower:
            sen = sen.lower()
        res = []
        for match in self.pat.finditer(sen):
            words, _ = wordpiece(
                    match.group(0), vocab=self.vocab, unk_token='[UNK]',
                    sentencepiece_style_vocab=self.sentencepiece_style_vocab)
            res.extend(words)
        return res
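
# Illustrative sketch, assuming a hypothetical vocab; CharTokenizer splits the
# text into alphanumeric runs / single non-space chars, then wordpieces each:
#   CharTokenizer(vocab={u'深', u'度', u'learn', u'##ing'})(u'深度learning'.encode('utf8'))
#   -> [u'深', u'度', u'learn', u'##ing']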


class WSSPTokenizer(object):
    def __init__(self, sp_model_dir, word_dict, ws=True, lower=True):
        self.ws = ws
        self.lower = lower
        with open(word_dict, 'rb') as f:
            self.dict = pickle.load(f, encoding='utf8')
        import sentencepiece as spm
        self.sp_model = spm.SentencePieceProcessor()
        self.window_size = 5
        self.sp_model.Load(sp_model_dir)

    def cut(self, chars):
        """forward maximum matching over self.dict with a fixed window size"""
        words = []
        idx = 0
        while idx < len(chars):
            matched = False
            # try the longest candidate first, then progressively shorter ones
            for i in range(self.window_size, 0, -1):
                cand = chars[idx: idx + i]
                if cand in self.dict:
                    words.append(cand)
                    matched = True
                    break
            if not matched:
                # no dictionary entry starts here; emit the single character
                i = 1
                words.append(chars[idx])
            idx += i
        return words
 
    def __call__(self, sen):
        sen = sen.decode('utf8')
        if self.ws:
            sen = [s for s in self.cut(sen) if s != ' ']
        else:
            sen = sen.split(' ')
        if self.lower:
            sen = [s.lower() for s in sen]
        sen = ' '.join(sen)
        ret = self.sp_model.EncodeAsPieces(sen)
        return ret


def build_2_pair(seg_a, seg_b, max_seqlen, cls_id, sep_id):
    token_type_a = np.ones_like(seg_a, dtype=np.int64) * 0
    token_type_b = np.ones_like(seg_b, dtype=np.int64) * 1
    sen_emb = np.concatenate([[cls_id], seg_a, [sep_id], seg_b, [sep_id]], 0)
    token_type_emb = np.concatenate([[0], token_type_a, [0], token_type_b, [1]], 0)

    seqlen = sen_emb.shape[0]
    #random truncate
    random_begin = 0 #np.random.randint(0, np.maximum(0, seqlen - max_seqlen) + 1,)
    sen_emb = sen_emb[random_begin: random_begin + max_seqlen]
    token_type_emb = token_type_emb[random_begin: random_begin + max_seqlen]

    return sen_emb, token_type_emb
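
# Illustrative example, assuming hypothetical ids cls_id=101, sep_id=102:
#   build_2_pair(np.array([7, 8]), np.array([9]), max_seqlen=128, cls_id=101, sep_id=102)
#   -> sen_emb        = [101, 7, 8, 102, 9, 102]
#      token_type_emb = [  0, 0, 0,   0, 1,   1]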


def build_1_pair(seg_a, max_seqlen, cls_id, sep_id):
    token_type_a = np.ones_like(seg_a, dtype=np.int64) * 0

    sen_emb = np.concatenate([[cls_id], seg_a, [sep_id]], 0)
    token_type_emb = np.concatenate([[0], token_type_a, [0]], 0)

    seqlen = sen_emb.shape[0]
    #random truncate
    random_begin = 0 #np.random.randint(0, np.maximum(0, seqlen - max_seqlen) + 1,)

    sen_emb = sen_emb[random_begin: random_begin + max_seqlen]
    token_type_emb = token_type_emb[random_begin: random_begin + max_seqlen]
    return sen_emb, token_type_emb
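
# Illustrative example, assuming hypothetical ids cls_id=101, sep_id=102:
#   build_1_pair(np.array([7, 8, 9]), max_seqlen=128, cls_id=101, sep_id=102)
#   -> sen_emb = [101, 7, 8, 9, 102], token_type_emb = [0, 0, 0, 0, 0]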


def expand_dims(*args):
    func = lambda i: np.expand_dims(i, -1)
    ret = [func(i) for i in args]
    return ret


def interleave(ds1, ds2):
    def gen():
        for i, j in six.moves.zip_longest(iter(ds1), iter(ds2)):
            if i is not None:
                yield i
            if j is not None:
                yield j
    return Dataset.from_generator_func(gen)
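
# Note: interleave() alternates elements of the two datasets round-robin,
# e.g. ds1 yielding a0, a1, a2 and ds2 yielding b0, b1 produce
# a0, b0, a1, b1, a2; zip_longest keeps the tail of the longer input.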