reader.py 11.5 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3
# -*- coding: utf-8 -*-

import numpy as np
Q
Qiao Longfei 已提交
4
import preprocess
J
JiabinYang 已提交
5
import logging
J
JiabinYang 已提交
6
import io
J
JiabinYang 已提交
7 8 9 10 11

# Route this module's log output through a single "fluid" logger at INFO level.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

Q
Qiao Longfei 已提交
12

J
JiabinYang 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
class NumpyRandomInt(object):
    """Buffered generator of random integers in the inclusive range [a, b].

    Integers are drawn from numpy in batches of ``buf_size`` to amortize
    the cost of the numpy call, then handed out one at a time.
    """

    def __init__(self, a, b, buf_size=1000):
        self.idx = 0
        # np.random.random_integers is deprecated (and removed in modern
        # NumPy). randint's upper bound is exclusive, so use b + 1 to keep
        # the original inclusive [a, b] range.
        self.buffer = np.random.randint(a, b + 1, buf_size)
        self.a = a
        self.b = b

    def __call__(self):
        """Return the next buffered random integer, refilling when drained."""
        if self.idx == len(self.buffer):
            self.buffer = np.random.randint(self.a, self.b + 1,
                                            len(self.buffer))
            self.idx = 0

        result = self.buffer[self.idx]
        self.idx += 1
        return result


Q
Qiao Longfei 已提交
31
class Word2VecReader(object):
    """Data reader for word2vec training.

    Loads the vocabulary dictionary plus the hierarchical-softmax path and
    code tables written by ``preprocess``, and exposes generator factories
    that yield (target, context) sample pairs -- optionally with the
    target's path/code -- sharded across trainers by line number.
    """

    def __init__(self,
                 dict_path,
                 data_path,
                 filelist,
                 trainer_id,
                 trainer_num,
                 window_size=5):
        self.window_size_ = window_size
        self.data_path_ = data_path
        self.filelist = filelist
        self.num_non_leaf = 0
        self.word_to_id_ = dict()
        self.id_to_word = dict()
        self.word_count = dict()
        self.word_to_path = dict()
        self.word_to_code = dict()
        self.trainer_id = trainer_id
        self.trainer_num = trainer_num

        word_all_count = 0
        word_counts = []
        word_id = 0

        # Vocabulary file: one "<word> <count>" pair per line, ordered so
        # that the line number defines the word id.
        with io.open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                word, count = fields[0], int(fields[1])
                self.word_count[word] = count
                self.word_to_id_[word] = word_id
                self.id_to_word[word_id] = word  # build id to word dict
                word_id += 1
                word_counts.append(count)
                word_all_count += count

        # Persist the word -> id mapping next to the dictionary file.
        with io.open(dict_path + "_word_to_id_", 'w+', encoding='utf-8') as f6:
            for k, v in self.word_to_id_.items():
                f6.write(k + " " + str(v) + '\n')

        self.dict_size = len(self.word_to_id_)
        self.word_frequencys = [
            float(count) / word_all_count for count in word_counts
        ]
        # BUG FIX: the second half of this message used to be concatenated
        # OUTSIDE the print() call ("print(a)) + b"), which raised
        # TypeError (NoneType + str) and never printed word_all_count.
        print("dict_size = " + str(self.dict_size) + " word_all_count = " +
              str(word_all_count))

        # Hierarchical-softmax path table: "<word>\t<space-separated ints>".
        with io.open(dict_path + "_ptable", 'r', encoding='utf-8') as f2:
            for line in f2:
                parts = line.split('\t')
                # np.fromstring(..., sep=' ') is deprecated; parse explicitly.
                path = np.array(parts[1].split(), dtype=int)
                self.word_to_path[parts[0]] = path
                # NOTE: overwritten on every line; keeps the first element of
                # the LAST line, matching the original behavior.
                self.num_non_leaf = path[0]
        print("word_ptable dict_size = " + str(len(self.word_to_path)))

        # Hierarchical-softmax code table, same layout as the path table.
        with io.open(dict_path + "_pcode", 'r', encoding='utf-8') as f3:
            for line in f3:
                parts = line.split('\t')
                self.word_to_code[parts[0]] = np.array(
                    parts[1].split(), dtype=int)
        print("word_pcode dict_size = " + str(len(self.word_to_code)))

        # Per-sample random window extent used by get_context_words.
        self.random_generator = NumpyRandomInt(1, self.window_size_ + 1)

    def _line_word_ids(self, line):
        """Strip a raw line with ``preprocess`` and map its in-vocabulary
        words to ids; out-of-vocabulary words are silently dropped."""
        line = preprocess.strip_lines(line, self.word_count)
        return [
            self.word_to_id_[word] for word in line.split()
            if word in self.word_to_id_
        ]

    def get_context_words(self, words, idx):
        """
        Get the context word list of target word.

        words: the words of the current line
        idx: input word index
        window_size: window size
        """
        target_window = self.random_generator()
        # Clamp the left edge of the window at the start of the line.
        start_point = idx - target_window
        if start_point < 0:
            start_point = 0
        end_point = idx + target_window
        # Words around idx, excluding the target word itself.
        targets = words[start_point:idx] + words[idx + 1:end_point + 1]

        return set(targets)

    def train(self, with_hs):
        """Return a generator factory for training samples.

        with_hs=False -> factory yields ([target_id], [context_id])
        with_hs=True  -> factory additionally yields the hierarchical-softmax
                         path and code arrays of the target word.
        """

        def _reader():
            for fname in self.filelist:
                with io.open(
                        self.data_path_ + "/" + fname, 'r',
                        encoding='utf-8') as f:
                    logger.info("running data in {}".format(self.data_path_ +
                                                            "/" + fname))
                    count = 1
                    for line in f:
                        # Shard input lines round-robin across trainers.
                        if self.trainer_id == count % self.trainer_num:
                            word_ids = self._line_word_ids(line)
                            for idx, target_id in enumerate(word_ids):
                                for context_id in self.get_context_words(
                                        word_ids, idx):
                                    yield [target_id], [context_id]
                        count += 1

        def _reader_hs():
            for fname in self.filelist:
                with io.open(
                        self.data_path_ + "/" + fname, 'r',
                        encoding='utf-8') as f:
                    logger.info("running data in {}".format(self.data_path_ +
                                                            "/" + fname))
                    count = 1
                    for line in f:
                        if self.trainer_id == count % self.trainer_num:
                            word_ids = self._line_word_ids(line)
                            for idx, target_id in enumerate(word_ids):
                                word = self.id_to_word[target_id]
                                for context_id in self.get_context_words(
                                        word_ids, idx):
                                    yield ([target_id], [context_id],
                                           [self.word_to_path[word]],
                                           [self.word_to_code[word]])
                        count += 1

        return _reader_hs if with_hs else _reader

    def async_train(self, with_hs):
        """Eagerly pre-generate training samples into 20 round-robin shard
        files under ./async_data/ (the directory must already exist).

        Runs to completion and returns None instead of yielding samples.
        """
        num_shards = 20

        def _open_shards():
            # One output file per shard, written round-robin.
            return [
                io.open(
                    "./async_data/async_" + str(i), 'w+', encoding='utf-8')
                for i in range(num_shards)
            ]

        def _reader():
            write_f = _open_shards()
            # try/finally so the shard files are closed even on error
            # (the original leaked them if an exception escaped the loop).
            try:
                for fname in self.filelist:
                    with io.open(
                            self.data_path_ + "/" + fname, 'r',
                            encoding='utf-8') as f:
                        logger.info("running data in {}".format(
                            self.data_path_ + "/" + fname))
                        count = 1
                        file_spilt_count = 0
                        for line in f:
                            if self.trainer_id == count % self.trainer_num:
                                word_ids = self._line_word_ids(line)
                                for idx, target_id in enumerate(word_ids):
                                    for context_id in self.get_context_words(
                                            word_ids, idx):
                                        content = ("1 " + str(target_id) +
                                                   " 1 " + str(context_id) +
                                                   '\n')
                                        # BUG FIX: content is already text in
                                        # Python 3 -- the old
                                        # content.decode('utf-8') call raised
                                        # AttributeError.
                                        write_f[file_spilt_count %
                                                num_shards].write(content)
                                        file_spilt_count += 1
                            count += 1
            finally:
                for wf in write_f:
                    wf.close()

        def _reader_hs():
            write_f = _open_shards()
            try:
                for fname in self.filelist:
                    with io.open(
                            self.data_path_ + "/" + fname, 'r',
                            encoding='utf-8') as f:
                        logger.info("running data in {}".format(
                            self.data_path_ + "/" + fname))
                        count = 1
                        file_spilt_count = 0
                        for line in f:
                            if self.trainer_id == count % self.trainer_num:
                                word_ids = self._line_word_ids(line)
                                for idx, target_id in enumerate(word_ids):
                                    word = self.id_to_word[target_id]
                                    # Path/code are per-target; hoisted out of
                                    # the context loop (same values each turn).
                                    path = [
                                        str(p) for p in self.word_to_path[word]
                                    ]
                                    code = [
                                        str(c) for c in self.word_to_code[word]
                                    ]
                                    for context_id in self.get_context_words(
                                            word_ids, idx):
                                        content = (
                                            "1 " + str(target_id) + " 1 " +
                                            str(context_id) + " " +
                                            str(len(path)) + " " +
                                            ' '.join(path) + " " +
                                            str(len(code)) + " " +
                                            ' '.join(code) + '\n')
                                        write_f[file_spilt_count %
                                                num_shards].write(content)
                                        file_spilt_count += 1
                            count += 1
            finally:
                for wf in write_f:
                    wf.close()

        if not with_hs:
            _reader()
        else:
            _reader_hs()

Q
Qiao Longfei 已提交
263 264

if __name__ == "__main__":
    # Smoke test: build a reader over one shard of the 1-billion-word corpus
    # and print the first 11 hierarchical-softmax samples.
    window_size = 5

    reader = Word2VecReader(
        "./data/1-billion_dict",
        "./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/",
        ["news.en-00001-of-00100"],
        0,
        1,
        window_size=window_size)  # was assigned but never passed before

    for i, (x, y, z, f) in enumerate(reader.train(True)()):
        print("x: " + str(x))
        print("y: " + str(y))
        print("path: " + str(z))
        print("code: " + str(f))
        print("\n")
        # break instead of exit(0): same effect at end-of-script, without
        # relying on the site-module `exit` helper.
        if i == 10:
            break