# -*- coding: utf-8 -*-
import os
import random
import re
import six
import argparse
import io
import math
import sys
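# under Python 2, make implicit str <-> unicode conversions default to utf-8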
if six.PY2:
    reload(sys)
    sys.setdefaultencoding('utf-8')
prog = re.compile("[^a-z ]", flags=0)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Paddle Fluid word2vec preprocess")
    parser.add_argument(
        '--build_dict_corpus_dir', type=str, help="The dir of the corpus")
    parser.add_argument(
        '--input_corpus_dir', type=str, help="The dir of the input corpus")
    parser.add_argument(
        '--output_corpus_dir', type=str, help="The dir of the output corpus")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./dict',
        help="The path of dictionary ")
    parser.add_argument(
        '--min_count',
        type=int,
        default=5,
        help="If the word count is less then min_count, it will be removed from dict"
    )
    parser.add_argument(
        '--downsample',
        type=float,
        default=0.001,
        help="filter word by downsample")
    parser.add_argument(
        '--filter_corpus',
        action='store_true',
        default=False,
        help='Filter the corpus and convert words to ids')
    parser.add_argument(
        '--build_dict',
        action='store_true',
        default=False,
        help='Build dict from corpus')
    return parser.parse_args()


def text_strip(text):
    # English preprocessing rule: lowercase and keep only letters and spaces
    return prog.sub("", text.lower())


# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py
# Unicode utility functions that work with Python 2 and 3
def native_to_unicode(s):
    if _is_unicode(s):
        return s
    try:
        return _to_unicode(s)
    except UnicodeDecodeError:
        res = _to_unicode(s, ignore_errors=True)
        return res


def _is_unicode(s):
    if six.PY2:
        if isinstance(s, unicode):
            return True
    else:
        if isinstance(s, str):
            return True
    return False


def _to_unicode(s, ignore_errors=False):
    if _is_unicode(s):
        return s
    error_mode = "ignore" if ignore_errors else "strict"
    return s.decode("utf-8", errors=error_mode)


def filter_corpus(args):
    """
    Filter the corpus and convert words to ids.
    """
    word_count = dict()
    word_to_id_ = dict()
    word_all_count = 0
    id_counts = []
    word_id = 0
    # read the dict: one "word count" pair per line
    with io.open(args.dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            word, count = line.split()[0], int(line.split()[1])
            word_count[word] = count
            word_to_id_[word] = word_id
            word_id += 1
            id_counts.append(count)
            word_all_count += count

    # write the word2id file
    print("write word2id file to : " + args.dict_path + "_word_to_id_")
    with io.open(
            args.dict_path + "_word_to_id_", 'w+', encoding='utf-8') as fid:
        for k, v in word_to_id_.items():
            fid.write(k + " " + str(v) + '\n')
    # filter the corpus and convert words to ids
    if not os.path.exists(args.output_corpus_dir):
        os.makedirs(args.output_corpus_dir)
    for file in os.listdir(args.input_corpus_dir):
        with io.open(
                os.path.join(args.output_corpus_dir, 'convert_' + file),
                "w",
                encoding='utf-8') as wf:
            with io.open(
                    os.path.join(args.input_corpus_dir, file),
                    encoding='utf-8') as rf:
                print(os.path.join(args.input_corpus_dir, file))
                for line in rf:
                    signal = False
                    line = text_strip(line)
                    words = line.split()
                    for item in words:
                        if item in word_count:
                            idx = word_to_id_[item]
                        else:
                            idx = word_to_id_[native_to_unicode('<UNK>')]
                        count_w = id_counts[idx]
                        corpus_size = word_all_count
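                        # word2vec-style subsampling: with word frequency
                        # f = count_w / corpus_size and t = args.downsample,
                        # keep_prob = (sqrt(f / t) + 1) * t / f, so frequent
                        # words are skipped with probability 1 - keep_prob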
                        keep_prob = (
                            math.sqrt(count_w /
                                      (args.downsample * corpus_size)) + 1
                        ) * (args.downsample * corpus_size) / count_w
                        r_value = random.random()
                        if r_value > keep_prob:
                            continue
                        wf.write(_to_unicode(str(idx) + " "))
                        signal = True
                    if signal:
                        wf.write(_to_unicode("\n"))


def build_dict(args):
    """
    Preprocess the data, generate the dictionary and save it into dict_path.
    :param corpus_dir: the input data dir.
    :param dict_path: the generated dict path. the data in dict is "word count"
    :param min_count: words with count <= min_count are merged into <UNK>.
    :return:
    """
    # word to count

    word_count = dict()

    for file in os.listdir(args.build_dict_corpus_dir):
        with io.open(
                args.build_dict_corpus_dir + "/" + file, encoding='utf-8') as f:
            print("build dict : ", args.build_dict_corpus_dir + "/" + file)
            for line in f:
                line = text_strip(line)
                words = line.split()
                for item in words:
                    if item in word_count:
                        word_count[item] = word_count[item] + 1
                    else:
                        word_count[item] = 1

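    # words whose count is <= min_count are dropped from the dict and their
    # counts are aggregated into a single <UNK> entry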
    item_to_remove = []
    for item in word_count:
        if word_count[item] <= args.min_count:
            item_to_remove.append(item)

    unk_sum = 0
    for item in item_to_remove:
        unk_sum += word_count[item]
        del word_count[item]
    # sort by count in descending order
    word_count[native_to_unicode('<UNK>')] = unk_sum
    word_count = sorted(
        word_count.items(), key=lambda word_count: -word_count[1])

    with io.open(args.dict_path, 'w+', encoding='utf-8') as f:
        for k, v in word_count:
            f.write(k + " " + str(v) + '\n')


if __name__ == "__main__":
    args = parse_args()
    if args.build_dict:
        build_dict(args)
    elif args.filter_corpus:
        filter_corpus(args)
    else:
        print(
            "Error: please specify either --build_dict or --filter_corpus.")