# preprocess.py (committed by Qiao Longfei)
# -*- coding: utf-8 -*

import argparse
import re
from collections import Counter


def parse_args():
    """
    Build the command-line parser of this script and parse the arguments.

    Registered options:
      --data_path: str, required.
      --dict_path: str, defaults to './dict'.
      --freq: int, defaults to 5.

    :return: the namespace returned by ``ArgumentParser.parse_args()``.
    """
    # One (flag, keyword-arguments) pair per option, in registration order.
    option_table = (
        ('--data_path', {
            'type': str,
            'required': True,
            'help': "The path of training dataset",
        }),
        ('--dict_path', {
            'type': str,
            'default': './dict',
            'help': "The path of generated dict",
        }),
        ('--freq', {
            'type': int,
            'default': 5,
            'help': "If the word count is less then freq, "
                    "it will be removed from dict",
        }),
    )

    arg_parser = argparse.ArgumentParser(
        description="Paddle Fluid word2 vector preprocess")
    for flag, settings in option_table:
        arg_parser.add_argument(flag, **settings)

    return arg_parser.parse_args()


def preprocess(data_path, dict_path, freq):
    """
    Preprocess the data, generate the dictionary and save it into dict_path.

    Every input line is lower-cased, and every character other than 'a'-'z'
    and the space is removed before the line is split into words. Note that
    this also removes tabs and other non-space whitespace, which joins the
    words on either side of them.

    :param data_path: the input data path.
    :param dict_path: the generated dict path. Each line of the dict is
        "word count".
    :param freq: count threshold. Words whose count is less than or equal to
        freq are left out of the dict.
    :return: None. The result is written to dict_path.
    """
    # Compiled once instead of on every line: matches any character that is
    # neither a lowercase ASCII letter nor a space.
    dropped_chars = re.compile("[^a-z ]")

    # word to count
    word_count = Counter()
    with open(data_path) as f:
        for line in f:
            word_count.update(dropped_chars.sub("", line.lower()).split())

    # Filter while writing rather than deleting from the counter first; the
    # words that end up in the file are the same.
    with open(dict_path, 'w') as f:
        for word, count in word_count.items():
            if count > freq:
                f.write(str(word) + " " + str(count) + '\n')


if __name__ == "__main__":
    # Script entry point: parse the command line, then build the dictionary.
    args = parse_args()
    preprocess(args.data_path, args.dict_path, args.freq)