# make_pretrain_data.py
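"""Convert a plain-text pretraining corpus into gzipped, length-prefixed
SequenceExample records (one record per blank-line-separated document).

Usage sketch, derived from the argparse definition at the bottom of this file
(file names here are illustrative):

    python make_pretrain_data.py corpus.txt corpus.gz --vocab vocab.txt [-v] [-c]
"""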
import sys
import argparse
import struct
#import jieba
import random as r
import re
import gzip
import logging
from itertools import accumulate
from functools import reduce, partial, wraps
from propeller import log
from propeller.paddle.data import feature_pb2, example_pb2
# from data_util import RawtextColumn

import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

def gen_segs(segment_piece):
    # Collapse a span of char-level segment ids to its minimum, repeated once per char.
    if len(segment_piece) == 0:
        return []
    else:
        return [min(segment_piece)] * len(segment_piece)

non_space_pat = re.compile(r'\S+')
def segment(inputs, inputs_segment):
    # Split `inputs` into whitespace-separated chunks, pairing each chunk with the
    # collapsed segment ids of the characters it covers.
    ret = [m.span() for m in non_space_pat.finditer(inputs)]
    ret = [(inputs[s: e], gen_segs(inputs_segment[s: e])) for s, e in ret]
    return ret
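# Illustrative example of segment(), worked out from the code above (not from the
# original sources):
#   segment('ab cd', list(range(5))) -> [('ab', [0, 0]), ('cd', [3, 3])]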


def tokenize(sen, seg_info):
    """
    char tokenizer (wordpiece english)
    normed txt(space seperated or not) => list of word-piece
    """
    sen = sen.lower()
    res_word, res_segments = [], []
    for match in pat.finditer(sen):
        words, pos = _wordpiece(match.group(0), vocab=vocab_set, unk_token='[UNK]')
        start_of_word = match.span()[0]
        for w, p in zip(words, pos):
            res_word.append(w)
            res_segments.append(gen_segs(seg_info[p[0] + start_of_word: p[1] + start_of_word]))
    return res_word, res_segments


def parse_txt(line):
    """Tokenize one raw byte line into (token ids, per-token min char offsets)."""
    if len(line) == 0:
        return [], []
    line = line.decode('utf8')
    ret_line, ret_seginfo = [], []

    for l, i in segment(line, list(range(len(line)))):
        for ll, ii in zip(*tokenize(l, i)):
            ret_line.append(ll)
            ret_seginfo.append(ii)

    if args.check and r.random() < 0.005:
        print('****', file=sys.stderr)
        print(line, file=sys.stderr)
        print('|'.join(ret_line), file=sys.stderr)
        print(ret_seginfo, file=sys.stderr)
        print('****', file=sys.stderr)

    ret_line = [vocab.get(w, vocab['[UNK]']) for w in ret_line]
    ret_seginfo = [[-1] if i == [] else i for i in ret_seginfo]  # empty seg info only for sentence piece
    ret_seginfo = [min(i) for i in ret_seginfo]
    return ret_line, ret_seginfo


def build_example(slots):
    txt, seginfo = slots
    txt_fe_list = feature_pb2.FeatureList(feature=[feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=t)) for t in txt])
    segsinfo_fe_list = feature_pb2.FeatureList(feature=[feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=s)) for s in seginfo])
    assert len(txt_fe_list.feature) == len(segsinfo_fe_list.feature), 'txt[%d] and seginfo[%d] size not match' % (len(txt_fe_list.feature), len(segsinfo_fe_list.feature))
    features = {
        'txt':  txt_fe_list,
        'segs': segsinfo_fe_list,
    }

    ex = example_pb2.SequenceExample(feature_lists=feature_pb2.FeatureLists(feature_list=features))
    return ex
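# Note (derived from build_example above): each record is a SequenceExample with two
# aligned FeatureLists, 'txt' (one int64 row of token ids per sentence) and 'segs'
# (the matching row of character offsets).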


def write_gz(serialized, to_file):
    l = len(serialized)
    packed_data = struct.pack('i%ds' % l, l, serialized)
    to_file.write(packed_data)
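
# Hedged sketch (not part of the original script): reading the records back, assuming
# the writer above, i.e. a native 4-byte int length prefix followed by that many bytes
# of a serialized SequenceExample.
def read_records(path):
    with gzip.open(path, 'rb') as f:
        while True:
            header = f.read(4)
            if not header:
                break
            (length,) = struct.unpack('i', header)
            yield example_pb2.SequenceExample.FromString(f.read(length))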


def build_bb(from_file, to_file):
    """Group consecutive non-empty lines into one document; a blank line ends the
    document and flushes it as a single gzipped, length-prefixed record."""
    slots = []
    for i, line in enumerate(from_file):
        line = line.strip()
        if args.verbose and i % 10000 == 0:
            log.debug(i)
        if len(line) == 0:
            if len(slots) != 0:
                transposed_slots = list(zip(*slots))
                ex = build_example(transposed_slots)
                write_gz(ex.SerializeToString(), to_file)
                slots = []
            continue
        parsed_line = parse_txt(line)
        slots.append(parsed_line)

    if len(slots) != 0:
        transposed_slots = list(zip(*slots))
        ex = build_example(transposed_slots)
        write_gz(ex.SerializeToString(), to_file)
        slots = []


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Pretrain Data Maker')
    parser.add_argument('src', type=str)
    parser.add_argument('tgt', type=str)
    parser.add_argument('--vocab', type=str, required=True)
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-c', '--check', action='store_true')

    args = parser.parse_args()
    log.setLevel(logging.DEBUG)


    from ernie.tokenizing_ernie import _wordpiece
    pat = re.compile(r'([a-zA-Z0-9]+|\S)')

    # Vocab file: one token per line (first tab-separated column); ids follow line order.
    vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab, 'rb'))}
    vocab_set = set(vocab.keys())


    with open(args.src, 'rb') as from_file, gzip.open(args.tgt, 'wb') as to_file:
        log.info('making gz from bb %s ==> %s' % (args.src, args.tgt))
        build_bb(from_file, to_file)
        log.info('done: %s' % args.tgt)