reader.py 2.5 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3
# -*- coding: utf-8 -*

import numpy as np
Q
Qiao Longfei 已提交
4
import preprocess
Q
Qiao Longfei 已提交
5 6


Q
Qiao Longfei 已提交
7 8 9 10 11 12
class Word2VecReader(object):
    """Produce (target word id, context word id) pairs from a text corpus.

    The vocabulary is loaded once from ``dict_path``; the corpus at
    ``data_path`` is only opened when the reader returned by ``train()``
    is iterated.

    Attributes:
        window_size_: upper bound for the randomly drawn context window.
        data_path_: path of the corpus file read by ``train()``.
        word_to_id_: mapping from word to its integer id (the order in
            which the word appears in the dictionary file, starting at 0).
        dict_size: number of distinct words in ``word_to_id_``.
        word_frequencys: for each dictionary line, its count divided by
            the sum of all counts.
    """

    def __init__(self, dict_path, data_path, window_size=5):
        """Load the vocabulary.

        dict_path: text file with one ``<word> <count>`` entry per line,
            separated by whitespace; extra columns are ignored.
        data_path: corpus file, one whitespace-tokenised sentence per line.
        window_size: maximum context window size passed to
            ``get_context_words``.
        """
        self.window_size_ = window_size
        self.data_path_ = data_path
        self.word_to_id_ = dict()

        word_all_count = 0
        word_counts = []
        word_id = 0

        with open(dict_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline at the end of the
                    # file) carry no entry; skip them instead of crashing.
                    continue
                word, count = fields[0], int(fields[1])
                self.word_to_id_[word] = word_id
                word_id += 1
                word_counts.append(count)
                word_all_count += count

        self.dict_size = len(self.word_to_id_)
        self.word_frequencys = [
            float(count) / word_all_count for count in word_counts
        ]
        # The whole message must be built *inside* the print call: under
        # Python 3 ``print(...)`` returns None, so concatenating to its
        # result raises TypeError.
        print("dict_size = " + str(self.dict_size) + " word_all_count = " +
              str(word_all_count))

    def get_context_words(self, words, idx, window_size):
        """
        Get the context word list of target word.

        words: the words of the current line
        idx: input word index
        window_size: window size

        The effective window is drawn uniformly from ``1..window_size`` on
        every call. The result is de-duplicated, so its order is
        unspecified; it can contain the target's own value when that value
        also occurs elsewhere inside the window.
        """
        target_window = np.random.randint(1, window_size + 1)
        # There may be fewer than `target_window` words before the target,
        # so clamp the start at 0. Slicing already clamps the end.
        start_point = max(idx - target_window, 0)
        end_point = idx + target_window
        # context words of the target word, excluding position `idx` itself
        targets = set(words[start_point:idx] + words[idx + 1:end_point + 1])
        return list(targets)

    def train(self):
        """Return a reader creator for the corpus at ``data_path_``.

        Calling the returned function yields a generator producing
        ``([target_id], [context_id])`` pairs. Words that are not in the
        vocabulary are dropped before the context windows are computed.
        """

        def _reader():
            with open(self.data_path_, 'r') as f:
                for line in f:
                    line = preprocess.text_strip(line)
                    word_ids = [
                        self.word_to_id_[word] for word in line.split()
                        if word in self.word_to_id_
                    ]
                    for idx, target_id in enumerate(word_ids):
                        context_word_ids = self.get_context_words(
                            word_ids, idx, self.window_size_)
                        for context_id in context_word_ids:
                            yield [target_id], [context_id]

        return _reader
Q
Qiao Longfei 已提交
61 62 63


if __name__ == "__main__":
    # Smoke test: build a reader over the enwik9 files and print the first
    # few (target, context) samples it produces.
    window_size = 10
    num_samples = 11  # samples to print before stopping

    reader = Word2VecReader("data/enwik9_dict", "data/enwik9", window_size)
    for i, (x, y) in enumerate(reader.train()()):
        print("x: " + str(x))
        print("y: " + str(y))
        print("\n")
        if i + 1 >= num_samples:
            # Leaving the loop ends the script normally (exit status 0)
            # without relying on the interactive `exit()` helper.
            break