utils.py 3.0 KB
Newer Older
Z
zhaopu 已提交
1
import os
C
caoying03 已提交
2 3
import logging
from collections import defaultdict
Z
zhaopu 已提交
4

C
caoying03 已提交
5
# Names exported by a star-import of this module.
# NOTE(review): load_reverse_dict is defined in this file but is not listed
# here -- confirm whether leaving it out of the public API is intentional.
__all__ = ["build_dict", "load_dict"]
Z
zhaopu 已提交
6

C
caoying03 已提交
7 8
# Module-level logger named "paddle". logging.getLogger returns the same
# logger object for the same name, so the DEBUG level set here applies to
# every user of the "paddle" logger once this module has been imported.
logger = logging.getLogger("paddle")
logger.setLevel(logging.DEBUG)
Z
zhaopu 已提交
9 10


C
caoying03 已提交
11 12 13 14 15
def build_dict(data_file,
               save_path,
               max_word_num,
               cutoff_word_fre=5,
               insert_extra_words=("<unk>", "<e>")):
    """
    Build a word-frequency dictionary from a text file and save it.

    Every line of ``data_file`` is stripped, lower-cased and split on
    whitespace; the occurrences of each resulting word are counted. Words
    are then written to ``save_path`` in descending order of frequency, one
    ``word<TAB>count`` pair per line, preceded by one ``word<TAB>-1`` line
    for each entry of ``insert_extra_words``.

    :param data_file: path of data file
    :type data_file: str
    :param save_path: path to save the word dictionary
    :type save_path: str
    :param max_word_num: at most this many of the most frequent words are
        written (the extra words are not counted against this limit).
        ``None`` or a negative value imposes no limit.
    :type max_word_num: int
    :param cutoff_word_fre: words whose frequencies are less than
        cutoff_word_fre will not be added into the word dictionary. Words
        whose frequency equals cutoff_word_fre are kept.
    :type cutoff_word_fre: int
    :param insert_extra_words: extra words defined by users that are written
        at the top of the dictionary, usually these include <unk>, start and
        ending marks
    :type insert_extra_words: list or tuple
    """
    word_count = defaultdict(int)
    with open(data_file, "r") as f:
        for idx, line in enumerate(f):
            if not (idx + 1) % 100000:
                logger.debug("processing %d lines ... ", idx + 1)
            for w in line.strip().lower().split():
                word_count[w] += 1

    # items() exists on both Python 2 and Python 3 (iteritems() does not).
    sorted_words = sorted(
        word_count.items(), key=lambda x: x[1], reverse=True)

    # Index of the first word that is rarer than the cutoff. The default
    # covers an empty vocabulary and the case where no word falls below the
    # cutoff (including when the rarest word's count equals the cutoff).
    stop_pos = next(
        (idx for idx, (_, count) in enumerate(sorted_words)
         if count < cutoff_word_fre),
        len(sorted_words))

    # None / negative never matched a loop index in the original
    # implementation, i.e. they behaved as "no size limit"; keep that.
    if max_word_num is not None and max_word_num >= 0:
        stop_pos = min(max_word_num, stop_pos)

    with open(save_path, "w") as fdict:
        for w in insert_extra_words:
            fdict.write("%s\t-1\n" % (w))
        for word, count in sorted_words[:stop_pos]:
            fdict.write("%s\t%d\n" % (word, count))
Z
zhaopu 已提交
55 56


C
caoying03 已提交
57
def load_dict(dict_path):
    """
    Load a word dictionary from the given file.

    Each line of the given file is a word in the word dictionary. The first
    column of the line, separated by TAB, is the key, while the zero-based
    line index is the value. If the same word appears on several lines, the
    index of the last such line is kept.

    :param dict_path: path of word dictionary
    :type dict_path: str
    :return: the dictionary mapping each word to its line index
    :rtype: dict
    """
    # The context manager closes the file; iterating the handle avoids
    # loading every line into memory first.
    with open(dict_path, "r") as fdict:
        return {
            line.strip().split("\t")[0]: idx
            for idx, line in enumerate(fdict)
        }

Z
zhaopu 已提交
71

C
caoying03 已提交
72
def load_reverse_dict(dict_path):
    """
    Load a reversed (index -> word) dictionary from the given file.

    Each line of the given file is a word in the word dictionary. The
    zero-based line index is the key, while the first column of the line,
    separated by TAB, is the value.

    :param dict_path: path of word dictionary
    :type dict_path: str
    :return: the dictionary mapping each line index to its word
    :rtype: dict
    """
    # The context manager closes the file; iterating the handle avoids
    # loading every line into memory first.
    with open(dict_path, "r") as fdict:
        return {
            idx: line.strip().split("\t")[0]
            for idx, line in enumerate(fdict)
        }