# prepro_utils.py
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unicodedata
import six
from functools import partial


SPIECE_UNDERLINE = '▁'


def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`.

    On Python 3 the result is a unicode ``str``; ``bytes`` input is decoded
    as UTF-8 with undecodable bytes silently dropped. On Python 2 the result
    is a byte ``str``; ``unicode`` input is encoded as UTF-8.

    Args:
        text: A ``str`` or ``bytes`` object (``str`` or ``unicode`` on
            Python 2).

    Returns:
        The native ``str`` form of ``text`` for the running interpreter.

    Raises:
        ValueError: If ``text`` is not a supported string type, or the
            interpreter is neither Python 2 nor Python 3.
    """
    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        # six.text_type is `unicode` on Python 2; spelling it this way avoids
        # an undefined-name reference when the module is linted on Python 3.
        elif isinstance(text, six.text_type):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def print_(*args):
    """Prints ``args`` after converting each one with printable_text.

    A ``list`` argument is converted element-wise and joined with single
    spaces, so it prints as one space-separated field instead of a list repr.
    List elements (and all other arguments) must be string types accepted by
    printable_text, otherwise ValueError is raised.
    """
    new_args = []
    for arg in args:
        if isinstance(arg, list):
            new_args.append(' '.join(printable_text(i) for i in arg))
        else:
            new_args.append(printable_text(arg))
    print(*new_args)


def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
    """Normalizes raw text before SentencePiece tokenization.

    Args:
        inputs: Text to normalize (``str``; on Python 2 either a UTF-8 byte
            ``str`` or ``unicode``).
        lower: If True, lowercase the result.
        remove_space: If True, strip leading/trailing whitespace and collapse
            every internal whitespace run to a single space.
        keep_accents: If False, apply NFKD normalization and drop combining
            characters, which removes accents (e.g. u"\\u00e9" -> u"e").
            NFKD also applies Unicode compatibility decompositions.

    Returns:
        The normalized text. On Python 2 the result is always ``unicode``.
    """
    if remove_space:
        outputs = ' '.join(inputs.strip().split())
    else:
        outputs = inputs
    # Map the two-character quote forms `` and '' to a plain double quote.
    outputs = outputs.replace("``", '"').replace("''", '"')

    # unicodedata below needs unicode text, so decode Python 2 byte strings.
    if six.PY2 and isinstance(outputs, str):
        outputs = outputs.decode('utf-8')

    if not keep_accents:
        # NFKD splits an accented character into its base character followed
        # by combining marks; dropping the combining marks removes the accent.
        outputs = unicodedata.normalize('NFKD', outputs)
        outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
    if lower:
        outputs = outputs.lower()

    return outputs


def encode_pieces(sp_model, text, return_unicode=True, sample=False,
                  nbest_size=64, alpha=0.1):
    """Tokenizes ``text`` into SentencePiece pieces, splitting off digit commas.

    Any piece that ends in ``<digit>,`` (e.g. a number fused with the comma
    after it) is re-encoded without the comma, and the comma is appended as
    its own piece.

    Args:
        sp_model: A loaded SentencePiece processor. It must provide
            ``EncodeAsPieces``, and ``SampleEncodeAsPieces`` when ``sample``
            is True.
        text: Text to tokenize.
        return_unicode: Used only on Python 2. If True, byte-string pieces are
            decoded back to ``unicode``. Ignored on Python 3.
        sample: If True, tokenize with ``SampleEncodeAsPieces`` instead of the
            deterministic ``EncodeAsPieces``.
        nbest_size: ``nbest_size`` argument for ``SampleEncodeAsPieces``
            (default 64, the previously hard-coded value).
        alpha: Smoothing argument for ``SampleEncodeAsPieces`` (default 0.1,
            the previously hard-coded value).

    Returns:
        A list of piece strings.
    """
    # note(zhiliny): in some systems, sentencepiece only accepts str for py2
    if six.PY2 and isinstance(text, six.text_type):
        text = text.encode('utf-8')

    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        pieces = sp_model.SampleEncodeAsPieces(text, nbest_size, alpha)

    # On Python 2 the pieces, and SPIECE_UNDERLINE itself (a non-ASCII literal
    # without a u prefix), are UTF-8 byte strings, so the marker spans several
    # bytes. Compare and strip it as a whole prefix instead of indexing one
    # character, which only works on Python 3.
    underline_len = len(SPIECE_UNDERLINE)
    new_pieces = []
    for piece in pieces:
        if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
            cur_pieces = sp_model.EncodeAsPieces(
                piece[:-1].replace(SPIECE_UNDERLINE, ''))
            # The original piece had no word-start marker, but re-encoding the
            # bare number may add one; drop it so the boundary is unchanged.
            if (cur_pieces
                    and not piece.startswith(SPIECE_UNDERLINE)
                    and cur_pieces[0].startswith(SPIECE_UNDERLINE)):
                if len(cur_pieces[0]) == underline_len:
                    cur_pieces = cur_pieces[1:]
                else:
                    cur_pieces[0] = cur_pieces[0][underline_len:]
            cur_pieces.append(piece[-1])
            new_pieces.extend(cur_pieces)
        else:
            new_pieces.append(piece)

    # note(zhiliny): convert back to unicode for py2
    if six.PY2 and return_unicode:
        new_pieces = [p.decode('utf-8') if isinstance(p, str) else p
                      for p in new_pieces]

    return new_pieces


def encode_ids(sp_model, text, sample=False):
    """Tokenizes ``text`` with encode_pieces and maps each piece to its id.

    Pieces are requested with ``return_unicode=False``, so on Python 2 they
    reach ``PieceToId`` as byte strings.

    Args:
        sp_model: A loaded SentencePiece processor providing ``PieceToId``
            (plus whatever encode_pieces needs).
        text: Text to tokenize.
        sample: Forwarded to encode_pieces; if True, use sampled tokenization.

    Returns:
        A list with ``sp_model.PieceToId(piece)`` for each piece, in order.
    """
    pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
    return [sp_model.PieceToId(piece) for piece in pieces]


if __name__ == '__main__':
    # Manual smoke test: for each sample sentence, print the sentence, the raw
    # SentencePiece pieces, the pieces from encode_pieces, and the ids from
    # encode_ids. Needs the `sentencepiece` package and the model file below
    # in the current working directory.
    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load('sp10m.uncased.v3.model')

    prepro_func = partial(preprocess_text, lower=True)
    sentences = [
        # Explicit unicode literal.
        u'I was born in 2000, and this is falsé.',
        # Lowercased, accent-stripped version of the same sentence.
        prepro_func('I was born in 2000, and this is falsé.'),
        # Native str literal (a byte string on Python 2).
        'I was born in 2000, and this is falsé.',
        # A longer number before the comma.
        'I was born in 92000, and this is falsé.',
    ]
    for index, sentence in enumerate(sentences):
        if index:
            print_('')  # blank separator between sentences
        print_(sentence)
        print_('ORIGINAL', sp.EncodeAsPieces(sentence))
        print_('OURS', encode_pieces(sp, sentence))
        print(encode_ids(sp, sentence))