# coding=utf-8
import collections
import os


# -- function --


def save_vocab(word_id_dict, vocab_file_name):
    """
    Save a vocabulary to a file, one "word<TAB>id" pair per line.

    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    :param vocab_file_name: vocab file name.
    """
    # 'with' guarantees the handle is closed even if a write fails
    # (the original leaked it on error and used Py2-only bytes concat).
    with open(vocab_file_name, 'w', encoding='utf-8') as f:
        for word, word_id in word_id_dict.items():
            f.write(word + '\t' + str(word_id) + '\n')
    print('save vocab to ' + vocab_file_name)

def load_vocab(vocab_file_name):
    """
    Load a vocab written by save_vocab (one "word<TAB>id" per line).

    :param vocab_file_name: vocab file name.
    :return: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    :raises IOError: if the vocab file does not exist.
    """
    if not os.path.isfile(vocab_file_name):
        raise IOError('vocab file does not exist!')
    vocab = {}  # renamed: the original shadowed the builtin 'dict'
    with open(vocab_file_name, encoding='utf-8') as f:
        for line in f:
            # Skip blank/too-short lines (e.g. a trailing newline).
            if len(line) < 2:
                continue
            word, word_id = line.strip().split('\t')
            vocab[word] = int(word_id)
    return vocab

def build_vocab(file_name, vocab_max_size):
    """
    Build a vocab from a whitespace-tokenized text file.

    Ids 0 and 1 are reserved for '<UNK>' and '<EOS>'; corpus words are
    numbered from 2 in descending frequency order.

    :param file_name: corpus file name (UTF-8 text).
    :param vocab_max_size: vocab's max size (reserved tokens excluded).
    :return: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    """
    words = []
    with open(file_name, encoding='utf-8', errors='ignore') as f:
        for line in f:
            words += line.strip().split()
    # most_common() both sorts by descending count and truncates, replacing
    # the manual sorted()+slice; it also handles an empty corpus, where the
    # original crashed at zip(*counter).
    top = collections.Counter(words).most_common(vocab_max_size)
    word_id_dict = {word: idx for idx, (word, _) in enumerate(top, start=2)}
    word_id_dict['<UNK>'] = 0
    word_id_dict['<EOS>'] = 1
    return word_id_dict

def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10):
Z
zhaopu7 已提交
61 62 63 64
    """
    create reader, each sample with fixed length.

    :param file_name: file name.
Z
zhaopu7 已提交
65
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
Z
zhaopu7 已提交
66 67 68
    :param sentence_len: each sample's length.
    :return: data reader.
    """
Z
zhaopu 已提交
69

Z
zhaopu7 已提交
70 71 72 73
    def reader():
        words = []
        UNK = word_id_dict['<UNK>']
        for line in open(file_name):
Z
zhaopu 已提交
74
            words += line.decode('utf-8', 'ignore').strip().split()
Z
zhaopu7 已提交
75 76
        ids = [word_id_dict.get(w, UNK) for w in words]
        words_len = len(words)
Z
zhaopu 已提交
77
        sentence_num = (words_len - 1) // sentence_len
Z
zhaopu7 已提交
78 79 80 81
        count = 0
        while count < sentence_num:
            start = count * sentence_len
            count += 1
Z
zhaopu 已提交
82 83
            yield ids[start:start + sentence_len], ids[start + 1:start + sentence_len + 1]

Z
zhaopu7 已提交
84 85
    return reader

def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
Z
zhaopu7 已提交
88 89 90 91
    """
    create reader, each line is a sample.

    :param file_name: file name.
Z
zhaopu7 已提交
92 93 94
    :param min_sentence_length: sentence's min length.
    :param max_sentence_length: sentence's max length.
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
Z
zhaopu7 已提交
95 96
    :return: data reader.
    """
Z
zhaopu 已提交
97

Z
zhaopu7 已提交
98 99 100
    def reader():
        UNK = word_id_dict['<UNK>']
        for line in open(file_name):
Z
zhaopu 已提交
101
            words = line.decode('utf-8', 'ignore').strip().split()
Z
zhaopu7 已提交
102 103 104 105 106 107 108
            if len(words) < min_sentence_length or len(words) > max_sentence_length:
                continue
            ids = [word_id_dict.get(w, UNK) for w in words]
            ids.append(word_id_dict['<EOS>'])
            target = ids[1:]
            target.append(word_id_dict['<EOS>'])
            yield ids[:], target[:]
Z
zhaopu 已提交
109

Z
zhaopu7 已提交
110 111
    return reader

def _reader_creator_for_NGram(file_name, N, word_id_dict):
Z
zhaopu7 已提交
114 115 116 117 118
    """
    create reader for ngram.

    :param file_name: file name.
    :param N: ngram's n.
Z
zhaopu7 已提交
119
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
Z
zhaopu7 已提交
120 121 122
    :return: data reader.
    """
    assert N >= 2
Z
zhaopu 已提交
123

Z
zhaopu7 已提交
124 125 126 127
    def reader():
        words = []
        UNK = word_id_dict['<UNK>']
        for line in open(file_name):
Z
zhaopu 已提交
128
            words += line.decode('utf-8', 'ignore').strip().split()
Z
zhaopu7 已提交
129 130
        ids = [word_id_dict.get(w, UNK) for w in words]
        words_len = len(words)
Z
zhaopu 已提交
131 132 133
        for i in range(words_len - N - 1):
            yield tuple(ids[i:i + N])

Z
zhaopu7 已提交
134 135
    return reader

def train_data(train_file, min_sentence_length, max_sentence_length, word_id_dict):
    """Return a line-based data reader over the training corpus."""
    return _read_by_line(
        train_file, min_sentence_length, max_sentence_length, word_id_dict)


def test_data(test_file, min_sentence_length, max_sentence_length, word_id_dict):
    """Return a line-based data reader over the test corpus."""
    return _read_by_line(
        test_file, min_sentence_length, max_sentence_length, word_id_dict)


def train_data_for_NGram(train_file, N, word_id_dict):
    """Return an N-gram data reader over the training corpus."""
    return _reader_creator_for_NGram(train_file, N, word_id_dict)


def test_data_for_NGram(test_file, N, word_id_dict):
    """Return an N-gram data reader over the test corpus."""
    return _reader_creator_for_NGram(test_file, N, word_id_dict)