preprocess.py 3.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
# -*- coding: utf-8 -*-
import os
import io
import re
import json
import click
import collections


def build_vocabulary(dataset, cutoff=0):
    """Build a character vocabulary from the dataset.

    Args:
        dataset: iterable of (title, author, paragraphs) tuples; characters
            are collected from every sentence in ``paragraphs`` (index 2).
        cutoff: minimum occurrence count a character needs to be kept.

    Returns:
        A tuple starting with the special tokens ``<unk>``, ``<s>``, ``<e>``,
        followed by the surviving characters ordered by descending frequency,
        ties broken by the character itself.
    """
    counts = collections.Counter(
        char for data in dataset for sent in data[2] for char in sent)
    kept = sorted(
        (item for item in counts.items() if item[1] >= cutoff),
        key=lambda item: (-item[1], item[0]))
    # Build the tuple directly instead of zip-unpacking: the original
    # ``vocab, _ = list(zip(*dictionary))`` raised ValueError whenever no
    # character survived the cutoff (e.g. empty dataset).
    vocab = tuple(char for char, _ in kept)
    return (u"<unk>", u"<s>", u"<e>") + vocab


@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to raw data")
@click.option("--outfile", type=str, help="Path to save the training data")
@click.option("--dictfile", type=str, help="Path to save the dictionary file")
def preprocess(datadir, outfile, dictfile):
    """Clean raw poem JSON files and emit training data plus a vocabulary.

    Reads every JSON file under ``datadir`` (each a list of records with
    ``title``, ``author`` and ``paragraphs`` keys), strips editorial notes
    and markup, skips poems containing unwanted characters, then writes:

    * ``outfile``  -- one poem per line: title, author and the '.'-joined
      paragraphs, separated by tabs;
    * ``dictfile`` -- one vocabulary character per line (cutoff 10).
    """
    dataset = []
    # NOTE(review): several of these patterns (and the duplicated brace
    # checks below) look as if full-width CJK punctuation was converted to
    # ASCII in this copy of the file -- as written, pattern4's lone '(' is
    # an unbalanced group and would not even compile.  Verify against the
    # original file's encoding before relying on them.
    note_pattern1 = re.compile(u"(.*?)", re.U)
    note_pattern2 = re.compile(u"〖.*?〗", re.U)
    note_pattern3 = re.compile(u"-.*?-。?", re.U)
    note_pattern4 = re.compile(u"(.*$", re.U)
    note_pattern5 = re.compile(u"。。.*)$", re.U)
    note_pattern6 = re.compile(u"。。", re.U)
    note_pattern7 = re.compile(u"[《》「」\[\]]", re.U)
    print("Loading raw data...")
    for fn in os.listdir(datadir):
        with io.open(os.path.join(datadir, fn), "r", encoding="utf8") as f:
            for data in json.load(f):
                title = data['title']
                author = data['author']
                # Concatenate paragraphs and remove all whitespace.
                p = "".join(data['paragraphs'])
                p = "".join(p.split())
                # Strip editorial notes / markup.
                p = note_pattern1.sub(u"", p)
                p = note_pattern2.sub(u"", p)
                p = note_pattern3.sub(u"", p)
                p = note_pattern4.sub(u"", p)
                p = note_pattern5.sub(u"。", p)
                p = note_pattern6.sub(u"。", p)
                p = note_pattern7.sub(u"", p)
                # Skip empty poems and any containing residual unwanted
                # characters.
                if (p == u"" or u"{" in p or u"}" in p or u"{" in p or
                        u"}" in p or u"、" in p or u":" in p or u";" in p or
                        u"!" in p or u"?" in p or u"●" in p or u"□" in p or
                        u"囗" in p or u")" in p):
                    continue
                # BUGFIX: filter() returns a lazy iterator on Python 3,
                # which has no len(); materialize non-empty segments into a
                # list so the length check works on both Python 2 and 3.
                paragraphs = [seg for seg in p.split(u"。") if seg]
                if len(paragraphs) > 1:
                    dataset.append((title, author, paragraphs))
    print("Finished...")

    print("Constructing vocabularies...")
    vocab = build_vocabulary(dataset, cutoff=10)
    with io.open(dictfile, "w", encoding="utf8") as f:
        for v in vocab:
            f.write(v + "\n")
    print("Finished...")

    print("Writing processed data...")
    with io.open(outfile, "w", encoding="utf8") as f:
        for data in dataset:
            title = data[0]
            author = data[1]
            paragraphs = ".".join(data[2])
            f.write("\t".join((title, author, paragraphs)) + "\n")
    print("Finished...")


if __name__ == "__main__":
    # Script entry point: click parses the command-line options
    # (--datadir/--outfile/--dictfile) and invokes preprocess().
    preprocess()