'''
NLP data processing; tokenizes text and creates vocab indexes.

This module is a direct copy of fastai's transforms.py. I only need the
Tokenizer and the Vocab classes, which are both defined here. Copying them
avoids pulling in the numerous fastai dependencies.

Credit for the code here to Jeremy Howard and the fastai team.
'''

from ..wdtypes import *

import sys
import os
import re
import html
import pickle  # used by Vocab.save / Vocab.load below
import torch
import spacy

from concurrent.futures.process import ProcessPoolExecutor
from collections import Counter, defaultdict
from spacy.symbols import ORTH


def partition(a:Collection, sz:int)->List[Collection]:
    "Split iterables `a` in equal parts of size `sz`"
    return [a[i:i+sz] for i in range(0, len(a), sz)] # type: ignore


def partition_by_cores(a:Collection, n_cpus:int)->List[Collection]:
    "Split data in `a` equally among `n_cpus` cores"
    return partition(a, len(a)//n_cpus + 1)


def ifnone(a:Any,b:Any)->Any:
    "`a` if `a` is not None, otherwise `b`."
    return b if a is None else a


def num_cpus()->Optional[int]:
    "Get number of cpus"
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        return os.cpu_count()

_default_cpus = min(16, num_cpus())
defaults = SimpleNamespace(cpus=_default_cpus, cmap='viridis', return_fig=False, silent=False)

__all__ = ['BaseTokenizer', 'SpacyTokenizer', 'Tokenizer', 'Vocab', 'fix_html', 'replace_all_caps', 'replace_rep', 'replace_wrep',
           'rm_useless_spaces', 'spec_add_spaces', 'BOS', 'EOS', 'FLD', 'UNK', 'PAD', 'TK_MAJ', 'TK_UP', 'TK_REP', 'TK_WREP',
           'deal_caps']

BOS,EOS,FLD,UNK,PAD = 'xxbos','xxeos','xxfld','xxunk','xxpad'
TK_MAJ,TK_UP,TK_REP,TK_WREP = 'xxmaj','xxup','xxrep','xxwrep'
defaults.text_spec_tok = [UNK,PAD,BOS,EOS,FLD,TK_MAJ,TK_UP,TK_REP,TK_WREP]


class BaseTokenizer():
    "Basic class for a tokenizer function."
    def __init__(self, lang:str):                      self.lang = lang
    def tokenizer(self, t:str) -> List[str]:           return t.split(' ')
    def add_special_cases(self, toks:Collection[str]): pass


class SpacyTokenizer(BaseTokenizer):
    "Wrapper around a spacy tokenizer to make it a `BaseTokenizer`."
    def __init__(self, lang:str):
        self.tok = spacy.blank(lang, disable=["parser","tagger","ner"])

    def tokenizer(self, t:str) -> List[str]:
        return [t.text for t in self.tok.tokenizer(t)]

    def add_special_cases(self, toks:Collection[str]):
        for w in toks:
            self.tok.tokenizer.add_special_case(w, [{ORTH: w}])


def spec_add_spaces(t:str) -> str:
    "Add spaces around / and # in `t`. \n"
    return re.sub(r'([/#\n])', r' \1 ', t)


def rm_useless_spaces(t:str) -> str:
    "Remove multiple spaces in `t`."
    return re.sub(' {2,}', ' ', t)


def replace_rep(t:str) -> str:
    "Replace repetitions at the character level in `t`."
    def _replace_rep(m:Match[str]) -> str:
        c,cc = m.groups()
        return f' {TK_REP} {len(cc)+1} {c} '
    re_rep = re.compile(r'(\S)(\1{3,})')
    return re_rep.sub(_replace_rep, t)
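
# Illustrative example (not part of the original fastai code): a character
# repeated at least 4 times in a row is collapsed into the TK_REP token, a
# count and the character itself, e.g.
#   replace_rep("coooool")   # -> 'c xxrep 5 o l'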


def replace_wrep(t:str) -> str:
    "Replace word repetitions in `t`."
    def _replace_wrep(m:Match[str]) -> str:
        c,cc = m.groups()
        return f' {TK_WREP} {len(cc.split())+1} {c} '
    re_wrep = re.compile(r'(\b\w+\W+)(\1{3,})')
    return re_wrep.sub(_replace_wrep, t)
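
# Illustrative example (not part of the original fastai code): a word repeated
# at least 4 times in a row is collapsed into the TK_WREP token, a count and
# the word; the extra whitespace it leaves behind is cleaned up later by
# rm_useless_spaces, e.g.
#   replace_wrep("the the the the the cat")   # -> ' xxwrep 5 the  cat'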


def fix_html(x:str) -> str:
    "List of replacements from html strings in `x`."
    re1 = re.compile(r'  +')
    x = x.replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace(
        'nbsp;', ' ').replace('#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace(
        '<br />', "\n").replace('\\"', '"').replace('<unk>',UNK).replace(' @.@ ','.').replace(
        ' @-@ ','-').replace(' @,@ ',',').replace('\\', ' \\ ')
    return re1.sub(' ', html.unescape(x))


def replace_all_caps(x:Collection[str]) -> Collection[str]:
    "Replace tokens in ALL CAPS in `x` by their lower version and add `TK_UP` before."
    res = []
    for t in x:
        if t.isupper() and len(t) > 1: res.append(TK_UP); res.append(t.lower())
        else: res.append(t)
    return res


def deal_caps(x:Collection[str]) -> Collection[str]:
    "Replace all Capitalized tokens in `x` by their lower version and add `TK_MAJ` before."
    res = []
    for t in x:
        if t == '': continue
        if t[0].isupper() and len(t) > 1 and t[1:].islower(): res.append(TK_MAJ)
        res.append(t.lower())
    return res
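
# Illustrative examples (not part of the original fastai code): ALL-CAPS tokens
# are lower-cased and preceded by TK_UP; Capitalized tokens are lower-cased and
# preceded by TK_MAJ, e.g.
#   replace_all_caps(['I', 'LOVE', 'cats'])   # -> ['I', 'xxup', 'love', 'cats']
#   deal_caps(['Hello', 'World'])             # -> ['xxmaj', 'hello', 'xxmaj', 'world']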

defaults.text_pre_rules = [fix_html, replace_rep, replace_wrep, spec_add_spaces, rm_useless_spaces]
defaults.text_post_rules = [replace_all_caps, deal_caps]


class Tokenizer():
    "Put together rules and a tokenizer function to tokenize text with multiprocessing."
    def __init__(self, tok_func:Callable=SpacyTokenizer, lang:str='en', pre_rules:ListRules=None,
                 post_rules:ListRules=None, special_cases:Collection[str]=None, n_cpus:int=None):
        self.tok_func,self.lang,self.special_cases = tok_func,lang,special_cases
        self.pre_rules  = ifnone(pre_rules,  defaults.text_pre_rules )
        self.post_rules = ifnone(post_rules, defaults.text_post_rules)
        self.special_cases = special_cases if special_cases is not None else defaults.text_spec_tok
        self.n_cpus = ifnone(n_cpus, defaults.cpus)

    def __repr__(self) -> str:
        res = f'Tokenizer {self.tok_func.__name__} in {self.lang} with the following rules:\n'
        for rule in self.pre_rules: res += f' - {rule.__name__}\n'
        for rule in self.post_rules: res += f' - {rule.__name__}\n'
        return res

    def process_text(self, t:str, tok:BaseTokenizer) -> List[str]:
        "Process one text `t` with tokenizer `tok`."
        for rule in self.pre_rules: t = rule(t)
        toks = tok.tokenizer(t)
        for rule in self.post_rules: toks = rule(toks)
        return toks

    def _process_all_1(self, texts:Collection[str]) -> List[List[str]]:
        "Process a list of `texts` in one process."
        tok = self.tok_func(self.lang)
        if self.special_cases: tok.add_special_cases(self.special_cases)
        return [self.process_text(str(t), tok) for t in texts]

    def process_all(self, texts:Collection[str]) -> List[List[str]]:
        "Process a list of `texts`."
        if self.n_cpus <= 1: return self._process_all_1(texts)
        with ProcessPoolExecutor(self.n_cpus) as e:
            return sum(e.map(self._process_all_1, partition_by_cores(texts, self.n_cpus)), [])
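
# Minimal usage sketch (not part of the original fastai code; the exact output
# depends on the spacy version and how it splits punctuation):
#
#   tok = Tokenizer(n_cpus=1)
#   tok.process_all(['I LOVE this library!!!'])
#   # -> roughly [['i', 'xxup', 'love', 'this', 'library', '!', '!', '!']]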


class Vocab():
    "Contain the correspondence between numbers and tokens and numericalize."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        self.stoi = defaultdict(int,{v:k for k,v in enumerate(self.itos)})

    def numericalize(self, t:Collection[str]) -> List[int]:
        "Convert a list of tokens `t` to their ids."
        return [self.stoi[w] for w in t]

    def textify(self, nums:Collection[int], sep=' ') -> List[str]:
        "Convert a list of `nums` to their tokens."
        return sep.join([self.itos[i] for i in nums]) if sep is not None else [self.itos[i] for i in nums] # type: ignore

    def __getstate__(self):
        return {'itos':self.itos}

    def __setstate__(self, state:dict):
        self.itos = state['itos']
        self.stoi = defaultdict(int,{v:k for k,v in enumerate(self.itos)})

    def save(self, path):
        "Save `self.itos` in `path`"
        pickle.dump(self.itos, open(path, 'wb'))

    @classmethod
    def create(cls, tokens:Tokens, max_vocab:int, min_freq:int) -> 'Vocab':
        "Create a vocabulary from a set of `tokens`."
        freq = Counter(p for o in tokens for p in o)
        itos = [o for o,c in freq.most_common(max_vocab) if c >= min_freq]
        for o in reversed(defaults.text_spec_tok):
            if o in itos: itos.remove(o)
            itos.insert(0, o)
        itos = itos[:max_vocab]
        if len(itos) < max_vocab: #Make sure vocab size is a multiple of 8 for fast mixed precision training
            while len(itos)%8 !=0: itos.append('xxfake')
        return cls(itos)

    @classmethod
    def load(cls, path):
        "Load the `Vocab` contained in `path`"
        itos = pickle.load(open(path, 'rb'))
        return cls(itos)
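

# Minimal usage sketch for Vocab (not part of the original fastai code):
#
#   toks = [['xxmaj', 'hello', 'world'], ['hello', 'again']]
#   vocab = Vocab.create(toks, max_vocab=100, min_freq=1)
#   ids = vocab.numericalize(['hello', 'world'])   # -> [9, 10] (special tokens occupy the first ids)
#   vocab.textify(ids)                             # -> 'hello world'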