# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid as fluid
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.tensor as tensor
import paddle.fluid.layers.control_flow as cf
import paddle.fluid.layers.io as io


class BowEncoder(object):
    """ bow-encoder """

    def __init__(self):
        self.param_name = ""

    def forward(self, emb):
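        # sum-pool the word embeddings over the sequence: a bag-of-words view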
        return nn.sequence_pool(input=emb, pool_type='sum')


class CNNEncoder(object):
    """ cnn-encoder"""

    def __init__(self,
                 param_name="cnn.w",
                 win_size=3,
                 ksize=128,
                 act='tanh',
                 pool_type='max'):
        self.param_name = param_name
        self.win_size = win_size
        self.ksize = ksize
        self.act = act
        self.pool_type = pool_type

    def forward(self, emb):
        return fluid.nets.sequence_conv_pool(
            input=emb,
            num_filters=self.ksize,
            filter_size=self.win_size,
            act=self.act,
            pool_type=self.pool_type,
            param_attr=self.param_name)


class GrnnEncoder(object):
    """ grnn-encoder """

    def __init__(self, param_name="grnn.w", hidden_size=128):
        self.param_name = param_name
        self.hidden_size = hidden_size

    def forward(self, emb):
        # dynamic_gru expects an input whose width is 3 * hidden_size, so
        # project the embedding first and feed the projection into the GRU
        fc0 = nn.fc(input=emb, size=self.hidden_size * 3)
        gru_h = nn.dynamic_gru(
            input=fc0,
            size=self.hidden_size,
            is_reverse=False,
            param_attr=self.param_name)
        return nn.sequence_pool(input=gru_h, pool_type='max')
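
# All three encoders expose the same forward(emb) interface: they take the
# embedded (variable-length LoD) sequence of one slot and return a fixed-size
# vector. An illustrative sketch:
#
#   enc = GrnnEncoder(param_name="grnn.w", hidden_size=128)
#   vec = enc.forward(emb)  # emb: embedded sequence; vec: [batch, hidden]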


'''This is a very simple Encoder factory;
most default argument values are used'''


class SimpleEncoderFactory(object):
    def __init__(self):
        pass

    ''' create an encoder through the create function '''

    def create(self, enc_type, enc_hid_size):
        if enc_type == "bow":
            bow_encode = BowEncoder()
            return bow_encode
        elif enc_type == "cnn":
            cnn_encode = CNNEncoder(ksize=enc_hid_size)
            return cnn_encode
        elif enc_type == "gru":
            rnn_encode = GrnnEncoder(hidden_size=enc_hid_size)
            return rnn_encode
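
# A minimal sketch of how the factory might be wired up (the encoder types
# and sizes here are illustrative, not prescribed by this file):
#
#   factory = SimpleEncoderFactory()
#   query_encoders = [factory.create("gru", 128) for _ in range(2)]
#   title_encoders = [factory.create("bow", 128) for _ in range(2)]
#
# Note that create() falls through and returns None for an unknown enc_type.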


class MultiviewSimnet(object):
    """ multi-view simnet """

    def __init__(self, embedding_size, embedding_dim, hidden_size):
        self.embedding_size = embedding_size
        self.embedding_dim = embedding_dim
        self.emb_shape = [self.embedding_size, self.embedding_dim]
        self.hidden_size = hidden_size
        self.margin = 0.1

    def set_query_encoder(self, encoders):
        self.query_encoders = encoders

    def set_title_encoder(self, encoders):
        self.title_encoders = encoders

    def get_correct(self, x, y):
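        # count pairs in the batch where x < y (cast bool to float, then sum)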
        less = tensor.cast(cf.less_than(x, y), dtype='float32')
        correct = nn.reduce_sum(less)
        return correct

    def train_net(self):
        # input fields for query, pos_title, neg_title
        q_slots = [
            io.data(
                name="q%d" % i, shape=[1], lod_level=1, dtype='int64')
            for i in range(len(self.query_encoders))
        ]
        pt_slots = [
            io.data(
                name="pt%d" % i, shape=[1], lod_level=1, dtype='int64')
            for i in range(len(self.title_encoders))
        ]
        nt_slots = [
            io.data(
                name="nt%d" % i, shape=[1], lod_level=1, dtype='int64')
            for i in range(len(self.title_encoders))
        ]

        # lookup embedding for each slot
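        # all query/title slots share one embedding matrix via param_attr="emb.w"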
        q_embs = [
            nn.embedding(
                input=query, size=self.emb_shape, param_attr="emb.w")
            for query in q_slots
        ]
        pt_embs = [
            nn.embedding(
                input=title, size=self.emb_shape, param_attr="emb.w")
            for title in pt_slots
        ]
        nt_embs = [
            nn.embedding(
                input=title, size=self.emb_shape, param_attr="emb.w")
            for title in nt_slots
        ]

        # encode each embedding field with encoder
        q_encodes = [
            self.query_encoders[i].forward(emb) for i, emb in enumerate(q_embs)
        ]
        pt_encodes = [
            self.title_encoders[i].forward(emb) for i, emb in enumerate(pt_embs)
        ]
        nt_encodes = [
            self.title_encoders[i].forward(emb) for i, emb in enumerate(nt_embs)
        ]

        # concat multi view for query, pos_title, neg_title
        q_concat = nn.concat(q_encodes)
        pt_concat = nn.concat(pt_encodes)
        nt_concat = nn.concat(nt_encodes)

        # projection of hidden layer
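        # pos/neg titles share "t_fc.w", so both pass through the same projection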
        q_hid = nn.fc(q_concat, size=self.hidden_size, param_attr='q_fc.w')
        pt_hid = nn.fc(pt_concat, size=self.hidden_size, param_attr='t_fc.w')
        nt_hid = nn.fc(nt_concat, size=self.hidden_size, param_attr='t_fc.w')

        # cosine of hidden layers
        cos_pos = nn.cos_sim(q_hid, pt_hid)
        cos_neg = nn.cos_sim(q_hid, nt_hid)

        # pairwise hinge_loss
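        # i.e. loss = max(0, margin - cos_pos + cos_neg), assembled in three
        # elementwise steps below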
        loss_part1 = nn.elementwise_sub(
            tensor.fill_constant_batch_size_like(
                input=cos_pos,
                shape=[-1, 1],
                value=self.margin,
                dtype='float32'),
            cos_pos)

        loss_part2 = nn.elementwise_add(loss_part1, cos_neg)

        loss_part3 = nn.elementwise_max(
            tensor.fill_constant_batch_size_like(
                input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'),
            loss_part2)

        avg_cost = nn.mean(loss_part3)
        # count pairs ranked correctly, i.e. where cos_neg < cos_pos
        correct = self.get_correct(cos_neg, cos_pos)

        return q_slots + pt_slots + nt_slots, avg_cost, correct

    def pred_net(self, query_fields, pos_title_fields, neg_title_fields):
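        # note: the *_fields arguments are not used here; the slot layout is
        # derived from the configured encoders, mirroring train_net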
        q_slots = [
            io.data(
                name="q%d" % i, shape=[1], lod_level=1, dtype='int64')
            for i in range(len(self.query_encoders))
        ]
        pt_slots = [
            io.data(
                name="pt%d" % i, shape=[1], lod_level=1, dtype='int64')
            for i in range(len(self.title_encoders))
        ]
        # lookup embedding for each slot
        q_embs = [
            nn.embedding(
                input=query, size=self.emb_shape, param_attr="emb.w")
            for query in q_slots
        ]
        pt_embs = [
            nn.embedding(
                input=title, size=self.emb_shape, param_attr="emb.w")
            for title in pt_slots
        ]
        # encode each embedding field with encoder
        q_encodes = [
            self.query_encoders[i].forward(emb) for i, emb in enumerate(q_embs)
        ]
        pt_encodes = [
            self.title_encoders[i].forward(emb) for i, emb in enumerate(pt_embs)
        ]
        # concat multi view for query, pos_title, neg_title
        q_concat = nn.concat(q_encodes)
        pt_concat = nn.concat(pt_encodes)
        # projection of hidden layer
        q_hid = nn.fc(q_concat, size=self.hidden_size, param_attr='q_fc.w')
        pt_hid = nn.fc(pt_concat, size=self.hidden_size, param_attr='t_fc.w')
        # cosine of hidden layers
        cos = nn.cos_sim(q_hid, pt_hid)
        return cos
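

# A minimal usage sketch, not part of the original training pipeline: the
# vocabulary size, dimensions, and slot counts below are illustrative
# assumptions rather than values mandated by the model.
if __name__ == '__main__':
    factory = SimpleEncoderFactory()
    # one encoder per input slot: GRU views for queries, BoW views for titles
    query_encoders = [factory.create("gru", 128) for _ in range(2)]
    title_encoders = [factory.create("bow", 128) for _ in range(2)]

    model = MultiviewSimnet(
        embedding_size=100000, embedding_dim=64, hidden_size=128)
    model.set_query_encoder(query_encoders)
    model.set_title_encoder(title_encoders)

    # build the train program: input slots, average hinge loss, and the
    # count of correctly ordered (cos_pos > cos_neg) pairs per batch
    all_slots, avg_cost, correct = model.train_net()
    fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)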