data_util.py 4.7 KB
Newer Older
Z
zhaopu7 已提交
1 2
# coding=utf-8
import collections
import io
import os
Z
zhaopu7 已提交
4

Z
zhaopu7 已提交
5
# -- function --
Z
zhaopu7 已提交
6

Z
zhaopu7 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
def save_vocab(word_id_dict, vocab_file_name):
    """
    save vocab.

    Each entry is written as one UTF-8 encoded ``word<TAB>id`` line, the
    format read back by ``load_vocab``.

    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    :param vocab_file_name: vocab file name.
    """
    # io.open with an explicit encoding behaves the same on Python 2 and 3
    # (the old ``k.encode('utf-8') + '\t'`` is a bytes + str TypeError on
    # Python 3), and ``with`` closes the file even if a write fails.
    with io.open(vocab_file_name, 'w', encoding='utf-8') as f:
        for word, word_id in word_id_dict.items():
            f.write(u'%s\t%s\n' % (word, word_id))
    print('save vocab to ' + vocab_file_name)

def load_vocab(vocab_file_name):
    """
    load vocab from file.

    The file holds one UTF-8 encoded ``word<TAB>id`` pair per line, as
    written by ``save_vocab``. Blank (or whitespace-only) lines are skipped.

    :param vocab_file_name: vocab file name.
    :return: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    :raises Exception: if ``vocab_file_name`` is not an existing file.
    """
    if not os.path.isfile(vocab_file_name):
        raise Exception('vocab file does not exist!')
    word_id_dict = {}
    # io.open decodes on both Python 2 and 3 (str.decode is gone in
    # Python 3); ``with`` guarantees the handle is closed.
    with io.open(vocab_file_name, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fields = line.split('\t')
            word_id_dict[fields[0]] = int(fields[1])
    return word_id_dict

def build_vocab(file_name, vocab_max_size):
    """
    build vocab.

    Tokens are the whitespace-separated words of ``file_name`` decoded as
    UTF-8 (undecodable bytes are dropped). The ``vocab_max_size`` most
    frequent words receive ids 2, 3, ... in order of descending frequency;
    ids 0 and 1 are reserved for '<UNK>' and '<EOS>'. Occurrences of
    '<UNK>' / '<EOS>' in the corpus are not counted as regular words, so the
    ids are always contiguous (0 .. len(result) - 1).

    :param file_name: input text file name.
    :param vocab_max_size: vocab's max size (number of regular words, excluding '<UNK>' and '<EOS>').
    :return: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    """
    counter = collections.Counter()
    # Count line by line instead of materializing every token of the corpus;
    # io.open decodes identically on Python 2 and 3, ``with`` closes the file.
    with io.open(file_name, encoding='utf-8', errors='ignore') as f:
        for line in f:
            counter.update(line.split())
    # A reserved token that also got a regular id would be overwritten below,
    # leaving a hole in the id range and wasting one vocab slot.
    for reserved in ('<UNK>', '<EOS>'):
        counter.pop(reserved, None)
    # most_common() also copes with an empty corpus, where the previous
    # ``zip(*counter)`` unpacking raised ValueError.
    word_id_dict = {word: word_id for word_id, (word, _) in
                    enumerate(counter.most_common(vocab_max_size), 2)}
    word_id_dict['<UNK>'] = 0
    word_id_dict['<EOS>'] = 1
    return word_id_dict

Z
zhaopu7 已提交
56
def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10):
Z
zhaopu7 已提交
57 58 59 60
    """
    create reader, each sample with fixed length.

    :param file_name: file name.
Z
zhaopu7 已提交
61
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
Z
zhaopu7 已提交
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
    :param sentence_len: each sample's length.
    :return: data reader.
    """
    def reader():
        words = []
        UNK = word_id_dict['<UNK>']
        for line in open(file_name):
            words += line.decode('utf-8','ignore').strip().split()
        ids = [word_id_dict.get(w, UNK) for w in words]
        words_len = len(words)
        sentence_num = (words_len-1) // sentence_len
        count = 0
        while count < sentence_num:
            start = count * sentence_len
            count += 1
            yield ids[start:start+sentence_len], ids[start+1:start+sentence_len+1]
    return reader

Z
zhaopu7 已提交
80
def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
Z
zhaopu7 已提交
81 82 83 84
    """
    create reader, each line is a sample.

    :param file_name: file name.
Z
zhaopu7 已提交
85 86 87
    :param min_sentence_length: sentence's min length.
    :param max_sentence_length: sentence's max length.
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
Z
zhaopu7 已提交
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
    :return: data reader.
    """
    def reader():
        UNK = word_id_dict['<UNK>']
        for line in open(file_name):
            words = line.decode('utf-8','ignore').strip().split()
            if len(words) < min_sentence_length or len(words) > max_sentence_length:
                continue
            ids = [word_id_dict.get(w, UNK) for w in words]
            ids.append(word_id_dict['<EOS>'])
            target = ids[1:]
            target.append(word_id_dict['<EOS>'])
            yield ids[:], target[:]
    return reader

Z
zhaopu7 已提交
103
def _reader_creator_for_NGram(file_name, N, word_id_dict):
Z
zhaopu7 已提交
104 105 106 107 108
    """
    create reader for ngram.

    :param file_name: file name.
    :param N: ngram's n.
Z
zhaopu7 已提交
109
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
Z
zhaopu7 已提交
110 111 112 113 114 115 116 117 118 119 120 121 122 123
    :return: data reader.
    """
    assert N >= 2
    def reader():
        words = []
        UNK = word_id_dict['<UNK>']
        for line in open(file_name):
            words += line.decode('utf-8','ignore').strip().split()
        ids = [word_id_dict.get(w, UNK) for w in words]
        words_len = len(words)
        for i in range(words_len-N-1):
            yield tuple(ids[i:i+N])
    return reader

Z
zhaopu7 已提交
124 125
def train_data(train_file, min_sentence_length, max_sentence_length, word_id_dict):
    """
    create reader for training data, each line of ``train_file`` is a sample.

    :param train_file: training file name.
    :param min_sentence_length: sentence's min length; shorter lines are skipped.
    :param max_sentence_length: sentence's max length; longer lines are skipped.
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    :return: data reader built by ``_read_by_line``.
    """
    line_reader = _read_by_line(
        file_name=train_file,
        min_sentence_length=min_sentence_length,
        max_sentence_length=max_sentence_length,
        word_id_dict=word_id_dict,
    )
    return line_reader
Z
zhaopu7 已提交
126

Z
zhaopu7 已提交
127 128
def test_data(test_file, min_sentence_length, max_sentence_length, word_id_dict):
    """
    create reader for test data, each line of ``test_file`` is a sample.

    :param test_file: test file name.
    :param min_sentence_length: sentence's min length; shorter lines are skipped.
    :param max_sentence_length: sentence's max length; longer lines are skipped.
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    :return: data reader built by ``_read_by_line``.
    """
    line_reader = _read_by_line(
        file_name=test_file,
        min_sentence_length=min_sentence_length,
        max_sentence_length=max_sentence_length,
        word_id_dict=word_id_dict,
    )
    return line_reader
Z
zhaopu7 已提交
129

Z
zhaopu7 已提交
130 131
def train_data_for_NGram(train_file, N, word_id_dict):
    """
    create ngram reader for training data.

    :param train_file: training file name.
    :param N: ngram's n.
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    :return: data reader built by ``_reader_creator_for_NGram``.
    """
    ngram_reader = _reader_creator_for_NGram(
        file_name=train_file, N=N, word_id_dict=word_id_dict)
    return ngram_reader
Z
zhaopu7 已提交
132

Z
zhaopu7 已提交
133 134
def test_data_for_NGram(test_file, N, word_id_dict):
    """
    create ngram reader for test data.

    :param test_file: test file name.
    :param N: ngram's n.
    :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
    :return: data reader built by ``_reader_creator_for_NGram``.
    """
    ngram_reader = _reader_creator_for_NGram(
        file_name=test_file, N=N, word_id_dict=word_id_dict)
    return ngram_reader