utils.py 2.0 KB
Newer Older
P
peterzhang2029 已提交
1
import os
2
import logging
P
peterzhang2029 已提交
3 4 5 6 7 8
from collections import defaultdict

# Module-level logger registered under the name "paddle"; its level is set
# to INFO.
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


P
peterzhang2029 已提交
9
def build_word_dict(data_dir, save_path, use_col=1, cutoff_fre=1):
    """Count word frequencies in a directory of TSV files and save them.

    Every regular file directly inside ``data_dir`` is read line by line
    (sub-directories are skipped). Each line is stripped and split on tabs,
    and the field at index ``use_col`` is taken as the document text. The
    text is split on "." and then on whitespace; every resulting token is
    counted.

    The result is written to ``save_path`` as one "word<TAB>count" line per
    entry, in descending count order. Words counted fewer than
    ``cutoff_fre`` times are omitted. An ``<unk>`` entry is always written,
    with its count set to ``cutoff_fre`` (replacing any count gathered for a
    literal ``<unk>`` token in the data).

    Args:
        data_dir: Directory holding the input files.
        save_path: Path of the dictionary file to write.
        use_col: Zero-based index of the tab-separated field to tokenize.
        cutoff_fre: Minimum count a word needs to be kept.
    """
    values = defaultdict(int)

    for file_name in os.listdir(data_dir):
        file_path = os.path.join(data_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r") as fdata:
            for line in fdata:
                line_splits = line.strip().split("\t")
                # Index use_col only exists when there are at least
                # use_col + 1 fields; skip rows that are too short.
                if len(line_splits) <= use_col:
                    continue
                doc = line_splits[use_col]
                for sent in doc.strip().split("."):
                    for w in sent.split():
                        values[w] += 1

    values['<unk>'] = cutoff_fre
    with open(save_path, "w") as f:
        # items() is available on both Python 2 and 3 (iteritems() is not).
        for v, count in sorted(
                values.items(), key=lambda x: x[1], reverse=True):
            # Entries are in descending count order, so the first one below
            # the cutoff means all remaining ones are below it as well.
            if count < cutoff_fre:
                break
            f.write("%s\t%d\n" % (v, count))

P
peterzhang2029 已提交
34

P
peterzhang2029 已提交
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
def build_label_dict(data_dir, save_path, use_col=0):
    """Count label occurrences in a directory of TSV files and save them.

    Every regular file directly inside ``data_dir`` is read line by line
    (sub-directories are skipped). Each line is stripped and split on tabs,
    and the field at index ``use_col`` is counted as a label.

    The result is written to ``save_path`` as one "label<TAB>count" line per
    distinct label, in descending count order.

    Args:
        data_dir: Directory holding the input files.
        save_path: Path of the dictionary file to write.
        use_col: Zero-based index of the tab-separated field holding the
            label.
    """
    values = defaultdict(int)

    for file_name in os.listdir(data_dir):
        file_path = os.path.join(data_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r") as fdata:
            for line in fdata:
                line_splits = line.strip().split("\t")
                # Index use_col only exists when there are at least
                # use_col + 1 fields; skip rows that are too short.
                if len(line_splits) <= use_col:
                    continue
                values[line_splits[use_col]] += 1

    with open(save_path, "w") as f:
        # items() is available on both Python 2 and 3 (iteritems() is not).
        for v, count in sorted(
                values.items(), key=lambda x: x[1], reverse=True):
            f.write("%s\t%d\n" % (v, count))


55 56 57
def load_dict(dict_path):
    """Load a dictionary file into a mapping from entry to line index.

    Each line of ``dict_path`` is stripped and split on tabs; the first
    field is used as the key and the zero-based line number as the value.
    If a key appears on several lines, the last line wins.

    Args:
        dict_path: Path of the dictionary file to read.

    Returns:
        A dict mapping the first tab-separated field of each line to that
        line's zero-based index.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(dict_path, "r") as fdict:
        return dict((line.strip().split("\t")[0], idx)
                    for idx, line in enumerate(fdict))
P
peterzhang2029 已提交
58 59 60 61 62


def load_reverse_dict(dict_path):
    """Load a dictionary file into a mapping from line index to entry.

    Each line of ``dict_path`` is stripped and split on tabs; the zero-based
    line number is used as the key and the first field as the value.

    Args:
        dict_path: Path of the dictionary file to read.

    Returns:
        A dict mapping each line's zero-based index to the first
        tab-separated field of that line.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(dict_path, "r") as fdict:
        return dict((idx, line.strip().split("\t")[0])
                    for idx, line in enumerate(fdict))