reader.py 4.0 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3
# -*- coding: utf-8 -*-

import numpy as np
Q
Qiao Longfei 已提交
4
import preprocess
J
JiabinYang 已提交
5
import logging
Z
zhangwenhui03 已提交
6 7
import math
import random
J
JiabinYang 已提交
8
import io
J
JiabinYang 已提交
9 10 11 12 13

# Set the record layout (timestamp - level - message) through the root
# logging configuration, then create the module-level logger named "fluid"
# and set its level to INFO.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

Q
Qiao Longfei 已提交
14

J
JiabinYang 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
class NumpyRandomInt(object):
    """Buffered source of random integers drawn from the closed range [a, b].

    Integers are generated ``buf_size`` at a time with NumPy and handed out
    one per call; once the buffer is exhausted a fresh buffer of the same
    length is drawn.

    Args:
        a: lowest value that can be returned (inclusive).
        b: highest value that can be returned (inclusive).
        buf_size: number of integers generated per refill; must be >= 1.

    Raises:
        ValueError: if ``buf_size`` is smaller than 1.
    """

    def __init__(self, a, b, buf_size=1000):
        if buf_size < 1:
            # An empty buffer could never serve a value: every refill would
            # be empty as well and the first call would fail on indexing.
            raise ValueError("buf_size must be a positive integer")
        self.a = a
        self.b = b
        self.idx = 0
        self.buffer = self._draw(buf_size)

    def _draw(self, size):
        """Return ``size`` fresh random integers in [self.a, self.b]."""
        # np.random.random_integers is deprecated; randint is its supported
        # replacement. randint excludes the upper bound, hence ``b + 1`` to
        # keep the range inclusive.
        return np.random.randint(self.a, self.b + 1, size)

    def __call__(self):
        """Return the next buffered integer, refilling the buffer if needed."""
        if self.idx == len(self.buffer):
            self.buffer = self._draw(len(self.buffer))
            self.idx = 0

        result = self.buffer[self.idx]
        self.idx += 1
        return result


Q
Qiao Longfei 已提交
33
class Word2VecReader(object):
    """Load a word-count dictionary and generate (target, context) id pairs.

    The dictionary file is read once at construction time. ``train`` returns
    a generator function that walks the configured data files and yields one
    ``([target_id], [context_id])`` pair per context word.
    """

    def __init__(self,
                 dict_path,
                 data_path,
                 filelist,
                 trainer_id,
                 trainer_num,
                 window_size=5):
        """Read the dictionary and prepare the reader.

        Args:
            dict_path: path of a UTF-8 text file whose lines hold a word and
                its integer count separated by whitespace. Ids are assigned
                in file order, starting at 0.
            data_path: directory that contains the files named in
                ``filelist``.
            filelist: iterable of data file names, relative to ``data_path``.
            trainer_id: a data line is used only when its 1-based line
                number modulo ``trainer_num`` equals this value.
            trainer_num: divisor used together with ``trainer_id``.
            window_size: base value for the random context window.

        Side effects:
            Writes the word -> id mapping to ``dict_path + "_word_to_id_"``
            and prints summary statistics to stdout.
        """
        self.window_size_ = window_size
        self.data_path_ = data_path
        self.filelist = filelist
        self.word_to_id_ = dict()
        self.id_to_word = dict()
        self.word_count = dict()
        self.trainer_id = trainer_id
        self.trainer_num = trainer_num

        word_all_count = 0
        id_counts = []
        word_id = 0

        with io.open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline) instead
                    # of failing with IndexError.
                    continue
                word, count = fields[0], int(fields[1])
                self.word_count[word] = count
                self.word_to_id_[word] = word_id
                self.id_to_word[word_id] = word  # build id to word dict
                word_id += 1
                id_counts.append(count)
                word_all_count += count

        self.word_all_count = word_all_count
        self.corpus_size_ = word_all_count
        self.dict_size = len(self.word_to_id_)
        self.id_counts_ = id_counts

        # write word2id file
        word_to_id_path = dict_path + "_word_to_id_"
        print("write word2id file to : " + word_to_id_path)
        with io.open(word_to_id_path, 'w+', encoding='utf-8') as f6:
            for k, v in self.word_to_id_.items():
                f6.write(k + " " + str(v) + '\n')

        print("corpus_size:", self.corpus_size_)
        self.id_frequencys = [
            float(count) / word_all_count for count in self.id_counts_
        ]
        # The whole message is built inside the print call: on Python 3
        # print() returns None, so concatenating to its result raises
        # TypeError.
        print("dict_size = " + str(self.dict_size) + " word_all_count = " +
              str(word_all_count))

        # NOTE(review): NumpyRandomInt's upper bound is inclusive, so the
        # drawn window can be as large as window_size + 1 -- confirm that
        # this is intended.
        self.random_generator = NumpyRandomInt(1, self.window_size_ + 1)

    def get_context_words(self, words, idx):
        """Return the context words surrounding ``words[idx]``.

        A window radius is drawn from ``self.random_generator``; the result
        is made of up to that many elements before ``idx`` followed by up to
        that many elements after it. The element at ``idx`` is excluded and
        the window is clipped at both ends of ``words``.

        Args:
            words: the words (ids) of the current line, as a list.
            idx: index of the target word inside ``words``.

        Returns:
            A new list with the context elements, in their original order.
        """
        target_window = self.random_generator()
        # Clamp at 0: a negative slice start would wrap around and select
        # from the end of the list.
        start_point = max(idx - target_window, 0)
        end_point = idx + target_window
        return words[start_point:idx] + words[idx + 1:end_point + 1]

    def train(self):
        """Return a generator function producing training pairs.

        The returned ``nce_reader`` opens every file of ``self.filelist``
        under ``self.data_path_``, parses the lines selected for this trainer
        as whitespace-separated integer ids, and yields
        ``([target_id], [context_id])`` for each context id of each target.
        Line numbering starts at 1 and restarts for every file.
        """

        def nce_reader():
            for filename in self.filelist:
                path = self.data_path_ + "/" + filename
                with io.open(path, 'r', encoding='utf-8') as f:
                    logger.info("running data in %s", path)
                    for count, line in enumerate(f, 1):
                        if self.trainer_id != count % self.trainer_num:
                            continue
                        word_ids = [int(w) for w in line.split()]
                        for idx, target_id in enumerate(word_ids):
                            context_word_ids = self.get_context_words(
                                word_ids, idx)
                            for context_id in context_word_ids:
                                yield [target_id], [context_id]

        return nce_reader