reader.py 5.2 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3
# -*- coding: utf-8 -*

import numpy as np
Q
Qiao Longfei 已提交
4
import preprocess
Q
Qiao Longfei 已提交
5

J
JiabinYang 已提交
6 7 8 9 10 11
import logging

# Timestamped record format for the root logger (no-op if it is already configured).
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
# Named logger used by the corpus readers below; INFO so per-file progress is shown.
logger = logging.getLogger(name="fluid")
logger.setLevel(level="INFO")

Q
Qiao Longfei 已提交
12

Q
Qiao Longfei 已提交
13
class Word2VecReader(object):
    """Load a word2vec vocabulary and turn a text corpus into training samples.

    The vocabulary file at ``dict_path`` holds one ``<word> <count>`` pair per
    line; ids are assigned in file order starting at 0. Two companion files
    are also read, ``dict_path + "_ptable"`` and ``dict_path + "_pcode"``, each
    with one ``<word>:<space-separated ints>`` line per word; within this class
    they are only used by the ``with_hs`` reader. As a side effect,
    ``__init__`` writes the word->id mapping to ``dict_path + "_word_to_id_"``.
    """

    def __init__(self, dict_path, data_path, filelist, window_size=5):
        """
        dict_path: path of the vocabulary file (see the class docstring).
        data_path: directory containing the corpus files.
        filelist: names of the files under ``data_path`` to read.
        window_size: maximum context radius used when sampling pairs.
        """
        self.window_size_ = window_size
        self.data_path_ = data_path
        self.filelist = filelist
        self.num_non_leaf = 0
        self.word_to_id_ = dict()
        self.id_to_word = dict()
        self.word_to_path = dict()
        self.word_to_code = dict()

        word_all_count = 0
        word_counts = []
        word_id = 0

        with open(dict_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                word, count = fields[0], int(fields[1])
                self.word_to_id_[word] = word_id
                self.id_to_word[word_id] = word  # build id -> word dict
                word_id += 1
                word_counts.append(count)
                word_all_count += count

        with open(dict_path + "_word_to_id_", 'w+') as f6:
            for k, v in self.word_to_id_.items():
                f6.write(str(k) + " " + str(v) + '\n')

        self.dict_size = len(self.word_to_id_)
        # Relative frequency of each word, indexed by word id.
        self.word_frequencys = [
            float(count) / word_all_count for count in word_counts
        ]
        # The whole message is built inside print(): in Python 3 print()
        # returns None, so appending to its result raised a TypeError.
        print("dict_size = " + str(self.dict_size) + " word_all_count = " +
              str(word_all_count))

        with open(dict_path + "_ptable", 'r') as f2:
            for word, path in self._read_id_table(f2):
                self.word_to_path[word] = path
                if path.size:
                    # Ends up holding the first id of the LAST path line read.
                    # NOTE(review): presumably the number of non-leaf nodes of
                    # the tree the ptable describes -- confirm with its writer.
                    self.num_non_leaf = path[0]
        print("word_ptable dict_size = " + str(len(self.word_to_path)))

        with open(dict_path + "_pcode", 'r') as f3:
            for word, code in self._read_id_table(f3):
                self.word_to_code[word] = code
        print("word_pcode dict_size = " + str(len(self.word_to_code)))

    @staticmethod
    def _read_id_table(f):
        """Yield ``(word, int ndarray)`` for every non-blank ``word:i j k`` line of ``f``."""
        for line in f:
            if not line.strip():
                continue
            # Split at the LAST ':' so a word that itself contains ':' is kept whole.
            word, _, ids = line.rpartition(':')
            yield word, np.array([int(t) for t in ids.split()], dtype=int)

    def get_context_words(self, words, idx, window_size):
        """
        Get the context word list of target word.

        words: the words (list) of the current line
        idx: index of the target word in ``words``
        window_size: maximum window radius; the radius actually used is drawn
            uniformly from [1, window_size] on every call

        Returns the distinct words within the sampled radius on both sides of
        ``words[idx]`` (the target position itself is excluded). Duplicates
        are removed via a set, so the order of the returned list carries no
        meaning.
        """
        target_window = np.random.randint(1, window_size + 1)
        # There may be fewer than target_window words before the target.
        start_point = max(idx - target_window, 0)
        end_point = idx + target_window
        targets = set(words[start_point:idx] + words[idx + 1:end_point + 1])
        return list(targets)

    def _iter_pairs(self):
        """Yield ``(target_id, context_id)`` pairs from every file in ``self.filelist``.

        Words missing from the vocabulary are dropped before windowing.
        """
        for file_name in self.filelist:
            file_path = self.data_path_ + "/" + file_name
            with open(file_path, 'r') as f:
                logger.info("running data in {}".format(file_path))
                for line in f:
                    line = preprocess.text_strip(line)
                    word_ids = [
                        self.word_to_id_[word] for word in line.split()
                        if word in self.word_to_id_
                    ]
                    for idx, target_id in enumerate(word_ids):
                        context_word_ids = self.get_context_words(
                            word_ids, idx, self.window_size_)
                        for context_id in context_word_ids:
                            yield target_id, context_id

    def train(self, with_hs=False):
        """Return a zero-argument generator function producing training samples.

        with_hs: when false (the default), each sample is
            ``([target_id], [context_id])``. When true, each sample also
            carries the context word's ``_pcode`` and ``_ptable`` arrays:
            ``([target_id], [context_id], [code], [path])``.
        """

        def _reader():
            for target_id, context_id in self._iter_pairs():
                yield [target_id], [context_id]

        def _reader_hs():
            for target_id, context_id in self._iter_pairs():
                context_word = self.id_to_word[context_id]
                yield [target_id], [context_id], [
                    self.word_to_code[context_word]
                ], [self.word_to_path[context_word]]

        return _reader_hs if with_hs else _reader
Q
Qiao Longfei 已提交
123 124 125


if __name__ == "__main__":
    # Smoke test: build a reader over the enwik9 corpus and print the first
    # 11 (target, context) samples.
    import itertools
    import os

    window_size = 10
    data_path = "data/enwik9"

    # filelist is the third positional parameter, so window_size must be
    # passed by keyword. Assumes data_path is a directory of text files.
    reader = Word2VecReader("data/enwik9_dict", data_path,
                            sorted(os.listdir(data_path)),
                            window_size=window_size)
    for x, y in itertools.islice(reader.train(with_hs=False)(), 11):
        print("x: " + str(x))
        print("y: " + str(y))
        print("\n")