net.py 5.3 KB
Newer Older
S
slf12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
neural network for word2vec
"""
from __future__ import print_function
B
Bai Yifan 已提交
18
import paddle
Z
zhouzj 已提交
19
import paddle.nn.functional as F
S
slf12 已提交
20 21


Z
zhouzj 已提交
22 23 24 25 26
def skip_gram_word2vec(dict_size,
                       embedding_size,
                       batch_size,
                       is_sparse=False,
                       neg_num=5):
    """Build the static-graph skip-gram word2vec training network.

    Each (input word, true context word) pair and ``neg_num`` sampled
    negative words are scored with a dot product plus a per-word bias, and
    trained with sigmoid binary cross-entropy (label 1 for the true word,
    label 0 for every negative).

    Args:
        dict_size (int): vocabulary size; number of rows of every embedding
            table created here.
        embedding_size (int): width of the word embedding vectors.
        batch_size (int): kept for backward compatibility only. The label
            tensors are sized from the logits at run time, so a final batch
            smaller than ``batch_size`` no longer causes a shape mismatch.
        is_sparse (bool): forwarded to every ``paddle.static.nn.embedding``
            call.
        neg_num (int): number of negative samples per input word; fixes the
            second dimension of the ``neg_label`` input.

    Returns:
        tuple: ``(avg_cost, py_reader)`` where ``avg_cost`` is the mean loss
        over the batch and ``py_reader`` is a non-iterable loader created
        with ``paddle.io.DataLoader.from_generator`` whose feed list is
        ``[input_word, true_label, neg_label]``.
    """
    input_word = paddle.static.data(
        name="input_word", shape=[None, 1], dtype='int64')
    true_word = paddle.static.data(
        name='true_label', shape=[None, 1], dtype='int64')
    neg_word = paddle.static.data(
        name="neg_label", shape=[None, neg_num], dtype='int64')

    # The loader may keep a reference to this list, so it is never mutated
    # afterwards; reshaped ids live in separate variables below.
    words = [input_word, true_word, neg_word]
    py_reader = paddle.io.DataLoader.from_generator(
        capacity=64, feed_list=words, use_double_buffer=True, iterable=False)

    # Flatten [N, 1] ids to [N] before the embedding lookups.
    input_ids = paddle.reshape(input_word, [-1])
    true_ids = paddle.reshape(true_word, [-1])

    init_width = 0.5 / embedding_size
    input_emb = paddle.static.nn.embedding(
        input=input_ids,
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=paddle.ParamAttr(
            name='emb',
            initializer=paddle.nn.initializer.Uniform(-init_width, init_width)))

    # Output-side weights and biases start at zero; the negative-sample
    # lookups further down reuse these same parameters via the names
    # 'emb_w' / 'emb_b'.
    true_emb_w = paddle.static.nn.embedding(
        input=true_ids,
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=paddle.ParamAttr(
            name='emb_w',
            initializer=paddle.nn.initializer.Constant(value=0.0)))
    true_emb_b = paddle.static.nn.embedding(
        input=true_ids,
        is_sparse=is_sparse,
        size=[dict_size, 1],
        param_attr=paddle.ParamAttr(
            name='emb_b',
            initializer=paddle.nn.initializer.Constant(value=0.0)))

    neg_ids = paddle.reshape(neg_word, shape=[-1])
    neg_ids.stop_gradient = True
    neg_emb_w = paddle.static.nn.embedding(
        input=neg_ids,
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=paddle.ParamAttr(name='emb_w', learning_rate=1.0))
    neg_emb_w_re = paddle.reshape(
        neg_emb_w, shape=[-1, neg_num, embedding_size])
    neg_emb_b = paddle.static.nn.embedding(
        input=neg_ids,
        is_sparse=is_sparse,
        size=[dict_size, 1],
        param_attr=paddle.ParamAttr(name='emb_b', learning_rate=1.0))
    neg_emb_b_vec = paddle.reshape(neg_emb_b, shape=[-1, neg_num])

    # Per-sample score of the true pair: dot(input_emb, true_emb_w) + bias.
    # The reduction must run over the embedding axis only; reducing over all
    # elements would yield a single batch-wide scalar shared by every sample.
    true_logits = paddle.add(
        paddle.sum(
            paddle.multiply(input_emb, true_emb_w), axis=1, keepdim=True),
        true_emb_b)

    # Batched dot products against the negatives:
    # [-1, 1, E] x [-1, E, neg_num] -> [-1, 1, neg_num] -> [-1, neg_num].
    input_emb_re = paddle.reshape(input_emb, shape=[-1, 1, embedding_size])
    neg_matmul = paddle.matmul(input_emb_re, neg_emb_w_re, transpose_y=True)
    neg_matmul_re = paddle.reshape(neg_matmul, shape=[-1, neg_num])
    neg_logits = paddle.add(neg_matmul_re, neg_emb_b_vec)

    # NCE-style loss. Labels take their shape from the logits at run time
    # instead of the fixed ``batch_size``, so partial batches still match.
    label_ones = paddle.ones_like(true_logits, dtype='float32')
    label_ones.stop_gradient = True
    label_zeros = paddle.zeros_like(neg_logits, dtype='float32')
    label_zeros.stop_gradient = True

    true_xent = F.binary_cross_entropy_with_logits(
        true_logits, label_ones, reduction='none')
    neg_xent = F.binary_cross_entropy_with_logits(
        neg_logits, label_zeros, reduction='none')
    cost = paddle.add(
        paddle.sum(true_xent, axis=1), paddle.sum(neg_xent, axis=1))
    avg_cost = paddle.mean(cost)
    return avg_cost, py_reader


def infer_network(vocab_size, emb_size):
    """Build the static-graph word-analogy inference network.

    For each query ``a : b :: c : ?`` the network forms
    ``emb_b - emb_a + emb_c`` from the shared ``emb`` table. It then scores
    every id in ``all_label`` by the matmul of that target with the
    L2-normalised embedding of the id.

    Args:
        vocab_size (int): number of rows of the ``emb`` table; also the
            leading dimension of the ``all_label`` input.
        emb_size (int): embedding width.

    Returns:
        tuple: ``(values, pred_idx)`` as returned by ``paddle.topk`` with
        ``k=4`` over the score matrix.
    """
    # Feed variables are declared in the order a, b, c, all_label.
    query_vars = [
        paddle.static.data(name=feed_name, shape=[None, 1], dtype='int64')
        for feed_name in ("analogy_a", "analogy_b", "analogy_c")
    ]
    vocab_ids = paddle.static.data(
        name="all_label", shape=[vocab_size, 1], dtype='int64')

    def _lookup(ids):
        # Flatten the ids to 1-D and look them up in the shared 'emb' table.
        flat_ids = paddle.reshape(ids, [-1])
        return paddle.static.nn.embedding(
            input=flat_ids, size=[vocab_size, emb_size], param_attr="emb")

    vocab_emb = _lookup(vocab_ids)
    emb_a, emb_b, emb_c = [_lookup(query) for query in query_vars]

    target = paddle.add(paddle.add(emb_b, -emb_a), emb_c)
    vocab_emb_unit = F.normalize(vocab_emb, p=2, axis=1)
    scores = paddle.matmul(x=target, y=vocab_emb_unit, transpose_y=True)
    values, pred_idx = paddle.topk(x=scores, k=4)
    return values, pred_idx