reader.py 7.4 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3
# -*- coding: utf-8 -*

import numpy as np
Q
Qiao Longfei 已提交
4
import preprocess
J
JiabinYang 已提交
5
import logging
J
JiabinYang 已提交
6
import io
J
JiabinYang 已提交
7 8 9 10 11

# Log records are rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
# Logger named "fluid", used by this module; INFO and above are enabled on it.
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

Q
Qiao Longfei 已提交
12

J
JiabinYang 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
class NumpyRandomInt(object):
    """Buffered source of random integers drawn from the closed range [a, b].

    Random numbers are generated ``buf_size`` at a time with NumPy and handed
    out one per call; the buffer is regenerated once it has been consumed.

    Args:
        a: lowest value that can be returned (inclusive).
        b: highest value that can be returned (inclusive).
        buf_size: how many integers are pre-generated per refill.
    """

    def __init__(self, a, b, buf_size=1000):
        self.idx = 0
        self.a = a
        self.b = b
        self.buffer = self._draw(buf_size)

    def _draw(self, size):
        """Return ``size`` fresh integers from the closed range [a, b]."""
        # np.random.random_integers is deprecated; randint with an exclusive
        # upper bound of b + 1 is the documented replacement and keeps b
        # reachable.
        return np.random.randint(self.a, self.b + 1, size)

    def __call__(self):
        """Return the next buffered integer, refilling the buffer if needed."""
        if self.idx == len(self.buffer):
            self.buffer = self._draw(len(self.buffer))
            self.idx = 0

        result = self.buffer[self.idx]
        self.idx += 1
        return result


Q
Qiao Longfei 已提交
31
class Word2VecReader(object):
    """Reader that turns tokenized text files into (target, context) samples.

    On construction the reader loads three files derived from ``dict_path``:

    * ``dict_path``: one ``<word> <count>`` entry per line; words receive
      consecutive integer ids in file order.
    * ``dict_path + "_ptable"``: ``<word>\\t<space separated ints>`` per line,
      stored in ``word_to_path``.
    * ``dict_path + "_pcode"``: same layout, stored in ``word_to_code``.

    It also writes the word -> id mapping to ``dict_path + "_word_to_id_"``.

    Args:
        dict_path: path of the word-count dictionary (prefix of the other
            dictionary files listed above).
        data_path: directory that contains the training files.
        filelist: names of the files inside ``data_path`` to read.
        trainer_id: a line is used when
            ``trainer_id == line_number % trainer_num`` (line numbers start
            at 1 in every file).
        trainer_num: divisor used in the line selection above.
        window_size: base for the randomly drawn context window.
    """

    def __init__(self,
                 dict_path,
                 data_path,
                 filelist,
                 trainer_id,
                 trainer_num,
                 window_size=5):
        self.window_size_ = window_size
        self.data_path_ = data_path
        self.filelist = filelist
        self.num_non_leaf = 0
        self.word_to_id_ = dict()
        self.id_to_word = dict()
        self.word_count = dict()
        self.word_to_path = dict()
        self.word_to_code = dict()
        self.trainer_id = trainer_id
        self.trainer_num = trainer_num

        word_all_count = 0
        word_counts = []
        word_id = 0

        with io.open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                word, count = fields[0], int(fields[1])
                self.word_count[word] = count
                self.word_to_id_[word] = word_id
                self.id_to_word[word_id] = word  # reverse lookup: id -> word
                word_id += 1
                word_counts.append(count)
                word_all_count += count

        # Persist the id assignment next to the dictionary file.
        with io.open(dict_path + "_word_to_id_", 'w+', encoding='utf-8') as f6:
            for k, v in self.word_to_id_.items():
                f6.write(k + " " + str(v) + '\n')

        self.dict_size = len(self.word_to_id_)
        self.word_frequencys = [
            float(count) / word_all_count for count in word_counts
        ]
        print("dict_size = " + str(self.dict_size) + " word_all_count = " + str(
            word_all_count))

        with io.open(dict_path + "_ptable", 'r', encoding='utf-8') as f2:
            for line in f2:
                if not line.strip():
                    continue
                word, path = self._parse_table_line(line)
                self.word_to_path[word] = path
                # Overwritten on every line, so the value that survives is the
                # first path entry of the last line in the file.
                # NOTE(review): presumably every path starts at the same node,
                # making this independent of the line - confirm.
                self.num_non_leaf = path[0]
        print("word_ptable dict_size = " + str(len(self.word_to_path)))

        with io.open(dict_path + "_pcode", 'r', encoding='utf-8') as f3:
            for line in f3:
                if not line.strip():
                    continue
                word, code = self._parse_table_line(line)
                self.word_to_code[word] = code
        print("word_pcode dict_size = " + str(len(self.word_to_code)))
        # NOTE(review): NumpyRandomInt draws from the closed range [a, b], so
        # the sampled window can be as large as window_size + 1 - confirm this
        # is intended.
        self.random_generator = NumpyRandomInt(1, self.window_size_ + 1)

    @staticmethod
    def _parse_table_line(line):
        """Split a ``<word>\\t<ints>`` line into the word and an int array."""
        fields = line.split('\t')
        return fields[0], np.fromstring(fields[1], dtype=int, sep=' ')

    def get_context_words(self, words, idx):
        """
        Get the set of context words around a target word.

        The window radius is drawn from ``self.random_generator`` on every
        call; the window is clipped at the start of ``words`` and naturally
        truncated at its end.

        words: the words (ids) of the current line
        idx: index of the target word inside ``words``

        Returns the distinct items found in the window, excluding position
        ``idx`` itself (an equal item at another position is still included).
        """
        target_window = self.random_generator()
        start_point = max(idx - target_window, 0)
        end_point = idx + target_window
        targets = words[start_point:idx] + words[idx + 1:end_point + 1]

        return set(targets)

    def _iter_pairs(self, with_other_dict):
        """Yield ``(target_id, context_id)`` for this trainer's lines."""
        for filename in self.filelist:
            path = self.data_path_ + "/" + filename
            with io.open(path, 'r', encoding='utf-8') as f:
                logger.info("running data in {}".format(path))
                # Line numbers start at 1 in every file.
                for count, line in enumerate(f, 1):
                    if self.trainer_id != count % self.trainer_num:
                        continue
                    if with_other_dict:
                        line = preprocess.strip_lines(line, self.word_count)
                    else:
                        line = preprocess.text_strip(line)
                    # Words missing from the dictionary are dropped.
                    word_ids = [
                        self.word_to_id_[word] for word in line.split()
                        if word in self.word_to_id_
                    ]
                    for idx, target_id in enumerate(word_ids):
                        for context_id in self.get_context_words(word_ids,
                                                                 idx):
                            yield target_id, context_id

    def train(self, with_hs, with_other_dict):
        """Build a sample generator function over ``self.filelist``.

        with_hs: when true, every sample additionally carries the target
            word's entries from ``word_to_path`` and ``word_to_code``.
        with_other_dict: when true, lines are cleaned with
            ``preprocess.strip_lines(line, self.word_count)``; otherwise with
            ``preprocess.text_strip(line)``.

        Returns a zero-argument function; calling it returns a generator that
        yields ``[target_id], [context_id]`` or, with ``with_hs``,
        ``[target_id], [context_id], [path], [code]``.
        """

        def _reader():
            for target_id, context_id in self._iter_pairs(with_other_dict):
                yield [target_id], [context_id]

        def _reader_hs():
            for target_id, context_id in self._iter_pairs(with_other_dict):
                word = self.id_to_word[target_id]
                yield [target_id], [context_id], [
                    self.word_to_path[word]
                ], [self.word_to_code[word]]

        return _reader_hs if with_hs else _reader
Q
Qiao Longfei 已提交
175 176 177


if __name__ == "__main__":
    # Manual smoke test: print the first few hierarchical-softmax samples
    # produced from one file of the 1-billion-word benchmark.
    window_size = 5

    reader = Word2VecReader(
        "./data/1-billion_dict",
        "./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/",
        ["news.en-00001-of-00100"],
        0,
        1,
        window_size=window_size)

    # train() requires both flags; with_other_dict=False selects
    # preprocess.text_strip for line cleaning.
    # NOTE(review): confirm False is the intended mode for this demo.
    for i, (x, y, z, f) in enumerate(reader.train(True, False)()):
        print("x: " + str(x))
        print("y: " + str(y))
        print("path: " + str(z))
        print("code: " + str(f))
        print("\n")
        if i == 10:
            # Eleven samples are enough for a visual check.
            break