# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
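"""Dygraph-to-static unit test for a skip-gram word2vec model.

The same toy corpus is trained once in imperative (dygraph) mode and once with
@to_static, and the per-step losses of the two runs are expected to match.
"""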

import math
import random
import unittest

import numpy as np

import paddle
from paddle import fluid
from paddle.jit.api import to_static
from paddle.nn import Embedding


def fake_text():
    corpus = []
    for i in range(100):
        line = "i love paddlepaddle"
        corpus.append(line)
    return corpus


corpus = fake_text()


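# Lowercase each line and split it into whitespace-separated tokens.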
def data_preprocess(corpus):
    new_corpus = []
    for line in corpus:
        line = line.strip().lower()
        line = line.split(" ")
        new_corpus.append(line)

    return new_corpus


corpus = data_preprocess(corpus)


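# Build the vocabulary: id 0 is reserved for the '[oov]' token, and words
# occurring fewer than min_freq times are folded into it.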
def build_dict(corpus, min_freq=3):
    word_freq_dict = {}
    for line in corpus:
        for word in line:
            if word not in word_freq_dict:
                word_freq_dict[word] = 0
            word_freq_dict[word] += 1

    word_freq_dict = sorted(
        word_freq_dict.items(), key=lambda x: x[1], reverse=True
    )

    word2id_dict = {}
    word2id_freq = {}
    id2word_dict = {}

    word2id_freq[0] = 1.0
    word2id_dict['[oov]'] = 0
    id2word_dict[0] = '[oov]'

    for word, freq in word_freq_dict:

        if freq < min_freq:
            word2id_freq[0] += freq
            continue

        curr_id = len(word2id_dict)
        word2id_dict[word] = curr_id
        word2id_freq[word2id_dict[word]] = freq
        id2word_dict[curr_id] = word

    return word2id_freq, word2id_dict, id2word_dict


word2id_freq, word2id_dict, id2word_dict = build_dict(corpus)
vocab_size = len(word2id_freq)
print("there are totoally %d different words in the corpus" % vocab_size)
for _, (word, word_id) in zip(range(50), word2id_dict.items()):
    print(
        "word %s, its id %d, its word freq %d"
        % (word, word_id, word2id_freq[word_id])
    )


def convert_corpus_to_id(corpus, word2id_dict):
    new_corpus = []
    for line in corpus:
        new_line = [
            word2id_dict[word]
            if word in word2id_dict
            else word2id_dict['[oov]']
            for word in line
        ]
        new_corpus.append(new_line)
    return new_corpus


corpus = convert_corpus_to_id(corpus, word2id_dict)


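# Word2vec-style subsampling of frequent words: a word id is kept with
# probability sqrt(1e-4 / freq * len(corpus)).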
def subsampling(corpus, word2id_freq):
    def keep(word_id):
        return random.uniform(0, 1) < math.sqrt(
            1e-4 / word2id_freq[word_id] * len(corpus)
        )

    new_corpus = []
    for line in corpus:
        new_line = [word for word in line if keep(word)]
        # NOTE: the original `line` is appended, so the subsampled `new_line`
        # is currently unused and the corpus passes through unchanged.
        new_corpus.append(line)
    return new_corpus


corpus = subsampling(corpus, word2id_freq)


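# Turn the id corpus into skip-gram training samples: for each center word,
# words inside a random window become positive pairs (label 1) and randomly
# drawn vocabulary ids become negative pairs (label 0).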
def build_data(
    corpus,
    word2id_dict,
    word2id_freq,
    max_window_size=3,
    negative_sample_num=10,
):

    dataset = []

    for line in corpus:
        for center_word_idx in range(len(line)):
            window_size = random.randint(1, max_window_size)
            center_word = line[center_word_idx]

            positive_word_range = (
                max(0, center_word_idx - window_size),
                min(len(line) - 1, center_word_idx + window_size),
            )
            positive_word_candidates = [
                line[idx]
                for idx in range(
                    positive_word_range[0], positive_word_range[1] + 1
                )
                if idx != center_word_idx and line[idx] != line[center_word_idx]
            ]

            if not positive_word_candidates:
                continue

            for positive_word in positive_word_candidates:
                dataset.append((center_word, positive_word, 1))

            i = 0
            while i < negative_sample_num:
                negative_word_candidate = random.randint(0, vocab_size - 1)

                if negative_word_candidate not in positive_word_candidates:
                    dataset.append((center_word, negative_word_candidate, 0))
                    i += 1

    return dataset


dataset = build_data(corpus, word2id_dict, word2id_freq)
for _, (center_word, target_word, label) in zip(range(50), dataset):
    print(
        "center_word %s, target %s, label %d"
        % (id2word_dict[center_word], id2word_dict[target_word], label)
    )


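# Batch the (center_word, target_word, label) samples into numpy arrays;
# eval_word_batch collects a few extra word ids per batch, which are yielded
# but not consumed by the model in this test.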
def build_batch(dataset, batch_size, epoch_num):

    center_word_batch = []
    target_word_batch = []
    label_batch = []
    eval_word_batch = []

    for epoch in range(epoch_num):
        for center_word, target_word, label in dataset:
            center_word_batch.append([center_word])
            target_word_batch.append([target_word])
            label_batch.append([label])

            if len(eval_word_batch) < 5:
                eval_word_batch.append([random.randint(0, 99)])
            elif len(eval_word_batch) < 10:
                eval_word_batch.append([random.randint(0, vocab_size - 1)])

            if len(center_word_batch) == batch_size:
                yield (
                    np.array(center_word_batch).astype("int64"),
                    np.array(target_word_batch).astype("int64"),
                    np.array(label_batch).astype("float32"),
                    np.array(eval_word_batch).astype("int64"),
                )
                center_word_batch = []
                target_word_batch = []
                label_batch = []
                eval_word_batch = []

    if len(center_word_batch) > 0:
        yield (
            np.array(center_word_batch).astype("int64"),
            np.array(target_word_batch).astype("int64"),
            np.array(label_batch).astype("float32"),
            np.array(eval_word_batch).astype("int64"),
        )


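# Skip-gram model with two embedding tables: `embedding` looks up the center
# word, `embedding_out` looks up the target (context or negative) word, and
# their dot product is the matching score.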
class SkipGram(paddle.nn.Layer):
    def __init__(self, name_scope, vocab_size, embedding_size, init_scale=0.1):
        super().__init__(name_scope)
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size

        # Embedding table for center words.
        self.embedding = Embedding(
            self.vocab_size,
            self.embedding_size,
            weight_attr=fluid.ParamAttr(
                name='embedding_para',
                initializer=paddle.nn.initializer.Uniform(
                    low=-0.5 / self.embedding_size,
                    high=0.5 / self.embedding_size,
                ),
            ),
        )

        # Embedding table for target (context / negative) words.
        self.embedding_out = Embedding(
            self.vocab_size,
            self.embedding_size,
            weight_attr=fluid.ParamAttr(
                name='embedding_out_para',
                initializer=paddle.nn.initializer.Uniform(
                    low=-0.5 / self.embedding_size,
                    high=0.5 / self.embedding_size,
                ),
            ),
        )

    @to_static
    def forward(self, center_words, target_words, label):
        center_words_emb = self.embedding(center_words)
        target_words_emb = self.embedding_out(target_words)

        # center_words_emb = [batch_size, 1, embedding_size]
        # target_words_emb = [batch_size, 1, embedding_size]
        # The dot product of the two embeddings scores each pair.
        word_sim = paddle.multiply(center_words_emb, target_words_emb)
        word_sim = paddle.sum(word_sim, axis=-1)

        pred = paddle.nn.functional.sigmoid(word_sim)

        # The loss is computed on the raw logits; `pred` is returned only for
        # inspection and is not used in the loss.
        loss = paddle.nn.functional.binary_cross_entropy_with_logits(
            word_sim, label
        )
        loss = paddle.mean(loss)

        return pred, loss


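# Training configuration for the toy run.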
batch_size = 512
epoch_num = 1
embedding_size = 200
learning_rate = 1e-3
total_steps = len(dataset) * epoch_num // batch_size


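# Run one training pass over the toy dataset; `to_static` toggles whether the
# model's forward runs through the dygraph-to-static translator. The per-step
# mean losses are returned so the two modes can be compared.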
def train(to_static):
    paddle.jit.enable_to_static(to_static)

    random.seed(0)
    np.random.seed(0)

    place = (
        fluid.CUDAPlace(0)
        if fluid.is_compiled_with_cuda()
        else fluid.CPUPlace()
    )
    with fluid.dygraph.guard(place):
        fluid.default_startup_program().random_seed = 1000
        fluid.default_main_program().random_seed = 1000

        skip_gram_model = SkipGram(
            "skip_gram_model", vocab_size, embedding_size
        )
        adam = fluid.optimizer.AdamOptimizer(
            learning_rate=learning_rate,
            parameter_list=skip_gram_model.parameters(),
        )

        step = 0
        ret = []
        for center_words, target_words, label, eval_words in build_batch(
            dataset, batch_size, epoch_num
        ):
            center_words_var = fluid.dygraph.to_variable(center_words)
            target_words_var = fluid.dygraph.to_variable(target_words)
            label_var = fluid.dygraph.to_variable(label)
            pred, loss = skip_gram_model(
                center_words_var, target_words_var, label_var
            )

            loss.backward()
            adam.minimize(loss)
            skip_gram_model.clear_gradients()

            step += 1
            mean_loss = np.mean(loss.numpy())
            print("step %d / %d, loss %f" % (step, total_steps, mean_loss))
            ret.append(mean_loss)
        return np.array(ret)


class TestWord2Vec(unittest.TestCase):
    def test_dygraph_static_same_loss(self):
        dygraph_loss = train(to_static=False)
        static_loss = train(to_static=True)
        np.testing.assert_allclose(dygraph_loss, static_loss, rtol=1e-05)


if __name__ == '__main__':
    unittest.main()