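# utils.py: helpers to build word/label dictionaries from tab-separated
# training data and to load them back (recovered revision committed by
# peterzhang2029).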
import os
import logging
from collections import defaultdict

# Module logger. logging.getLogger returns the same Logger object for the
# same name, so this is shared with any other code asking for "paddle".
logger = logging.getLogger(name="paddle")
logger.setLevel(level=logging.INFO)


def build_word_dict(data_dir, save_path, use_col=1, cutoff_fre=1):
    """
    Build word dictionary from training data.

    Every regular file directly inside ``data_dir`` is read line by line.
    Entries that are not regular files (e.g. sub-directories) are skipped.
    Each line is stripped and split on tabs, and column ``use_col`` is taken
    as the text. The text is split on "." and then on whitespace to obtain
    words. Lines that have no column ``use_col`` (including blank lines) are
    skipped.

    The dictionary is written to ``save_path`` as one ``word<TAB>count``
    entry per line. Entries are sorted by descending count, with ties broken
    by the word itself so the output is reproducible. An ``<unk>`` entry with
    count ``cutoff_fre`` is always written. Words whose count is below
    ``cutoff_fre`` are dropped.

    :param data_dir: The directory of training dataset.
    :type data_dir: str
    :param save_path: The path where the word dictionary will be saved.
    :type save_path: str
    :param use_col: The index of the text column after splitting a line on
                    tabs.
    :type use_col: int
    :param cutoff_fre: A word is not added to the dictionary if its
                       frequency is less than cutoff_fre.
    :type cutoff_fre: int
    """
    values = defaultdict(int)

    for file_name in os.listdir(data_dir):
        file_path = os.path.join(data_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r") as fdata:
            for line in fdata:
                line_splits = line.strip().split("\t")
                # Column ``use_col`` exists only when there are MORE than
                # ``use_col`` fields, hence ``<=`` rather than ``<``.
                if len(line_splits) <= use_col:
                    continue
                doc = line_splits[use_col]
                for sent in doc.strip().split("."):
                    for w in sent.split():
                        values[w] += 1

    # Force <unk> in with exactly the cutoff count so it always survives the
    # cutoff below; this overrides any literal "<unk>" token counted above.
    values['<unk>'] = cutoff_fre
    # .items() works on Python 2 and 3 (dict.iteritems is Python 2 only).
    # Sorting on (-count, word) keeps the highest counts first and makes ties
    # independent of dict and os.listdir ordering.
    ranked = sorted(values.items(), key=lambda kv: (-kv[1], kv[0]))
    with open(save_path, "w") as f:
        for v, count in ranked:
            # Entries are in descending count order, so once one falls
            # below the cutoff, all the remaining ones do too.
            if count < cutoff_fre:
                break
            f.write("%s\t%d\n" % (v, count))

def build_label_dict(data_dir, save_path, use_col=0):
    """
    Build label dictionary from training data.

    Every regular file directly inside ``data_dir`` is read line by line.
    Entries that are not regular files (e.g. sub-directories) are skipped.
    Each non-blank line is stripped and split on tabs, and column
    ``use_col`` is counted as a label. Lines that have no column ``use_col``
    are skipped.

    The dictionary is written to ``save_path`` as one ``label<TAB>count``
    entry per line. Entries are sorted by descending count, with ties broken
    by the label itself so the output is reproducible.

    :param data_dir: The directory of training dataset.
    :type data_dir: str
    :param save_path: The path where the label dictionary will be saved.
    :type save_path: str
    :param use_col: The index of the label column after splitting a line on
                    tabs.
    :type use_col: int
    """
    values = defaultdict(int)

    for file_name in os.listdir(data_dir):
        file_path = os.path.join(data_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r") as fdata:
            for line in fdata:
                line = line.strip()
                # A blank line carries no label; without this check it would
                # be counted as the empty-string label when use_col == 0.
                if not line:
                    continue
                line_splits = line.split("\t")
                # Column ``use_col`` exists only when there are MORE than
                # ``use_col`` fields, hence ``<=`` rather than ``<``.
                if len(line_splits) <= use_col:
                    continue
                values[line_splits[use_col]] += 1

    # .items() works on Python 2 and 3 (dict.iteritems is Python 2 only).
    # Sorting on (-count, label) makes label indices reproducible even when
    # several labels have the same count.
    ranked = sorted(values.items(), key=lambda kv: (-kv[1], kv[0]))
    with open(save_path, "w") as f:
        for v, count in ranked:
            f.write("%s\t%d\n" % (v, count))


def load_dict(dict_path):
    """
    Load a dictionary file into a ``{key: index}`` mapping.

    Only the first tab-separated field of each (stripped) line is used as the
    key. Its index is the 0-based line number. For files written by
    build_word_dict / build_label_dict, which sort entries by descending
    count, more frequent keys therefore get smaller indices. If a key appears
    on several lines, the last occurrence wins.

    :param dict_path: The path of the dictionary file.
    :type dict_path: str
    :return: A dict mapping each key to its line index.
    :rtype: dict
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(dict_path, "r") as fdict:
        return dict((line.strip().split("\t")[0], idx)
                    for idx, line in enumerate(fdict))


def load_reverse_dict(dict_path):
    """
    Load the reversed dictionary from a dictionary file.

    Each 0-based line number becomes a key, and the first tab-separated
    field of that (stripped) line becomes its value. When every key in the
    file is unique, this is the inverse of ``load_dict`` on the same file.

    :param dict_path: The path of the dictionary file.
    :type dict_path: str
    :return: A dict mapping each line index to its key.
    :rtype: dict
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(dict_path, "r") as fdict:
        return dict((idx, line.strip().split("\t")[0])
                    for idx, line in enumerate(fdict))