nets.py 7.6 KB
Newer Older
Q
Qiao Longfei 已提交
1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
D
dongdaxiang 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
D
dongdaxiang 已提交
14

D
dongdaxiang 已提交
15
import paddle.fluid as fluid
16 17 18 19 20
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.tensor as tensor
import paddle.fluid.layers.control_flow as cf
import paddle.fluid.layers.io as io

D
dongdaxiang 已提交
21 22 23

class BowEncoder(object):
    """Bag-of-words encoder: sum-pools token embeddings over each sequence."""

    def __init__(self):
        # Present for interface parity with the other encoders;
        # forward() does not read it.
        self.param_name = ""

    def forward(self, emb):
        """Sum-pool ``emb`` over its sequence dimension.

        Args:
            emb: embedding output for one input slot.

        Returns:
            The tensor produced by ``nn.sequence_pool`` with ``'sum'`` pooling.
        """
        pooled = nn.sequence_pool(input=emb, pool_type='sum')
        return pooled

D
dongdaxiang 已提交
31 32 33

class CNNEncoder(object):
    """Convolutional sequence encoder: sequence convolution followed by pooling."""

    def __init__(self,
                 param_name="cnn.w",
                 win_size=3,
                 ksize=128,
                 act='tanh',
                 pool_type='max'):
        """
        Args:
            param_name: name given to the convolution weight; encoders created
                with the same name refer to the same weight.
            win_size: convolution window (filter) size.
            ksize: number of convolution filters.
            act: activation passed to the convolution.
            pool_type: sequence pooling applied after the convolution.
        """
        self.param_name, self.win_size, self.ksize = param_name, win_size, ksize
        self.act, self.pool_type = act, pool_type

    def forward(self, emb):
        """Run ``fluid.nets.sequence_conv_pool`` over ``emb`` and return its output."""
        conv_kwargs = dict(
            num_filters=self.ksize,
            filter_size=self.win_size,
            act=self.act,
            pool_type=self.pool_type,
            param_attr=str(self.param_name))
        # NOTE(review): no bias name is given, so the bias is auto-named per
        # call and is not shared the way the named weight is — confirm intent.
        return fluid.nets.sequence_conv_pool(input=emb, **conv_kwargs)
D
dongdaxiang 已提交
55

56

D
dongdaxiang 已提交
57 58
class GrnnEncoder(object):
    """GRU sequence encoder: linear projection -> dynamic GRU -> max pooling."""

    def __init__(self, param_name="grnn.w", hidden_size=128):
        """
        Args:
            param_name: base name for this encoder's parameters; encoders
                created with the same name share all of their parameters.
            hidden_size: GRU hidden size.
        """
        self.param_name = param_name
        self.hidden_size = hidden_size

    def forward(self, emb):
        """Encode ``emb`` with a GRU and max-pool the hidden states over the sequence.

        Every parameter, biases included, is named from ``param_name``. With
        unnamed biases Paddle generates a fresh bias per call, so applying the
        same encoder to two inputs (e.g. positive and negative titles) would
        share the weights but not the biases.
        """
        base = str(self.param_name)
        # dynamic_gru consumes an input already projected to 3 * hidden_size.
        fc0 = nn.fc(
            input=emb,
            size=self.hidden_size * 3,
            param_attr=base + "_fc",
            bias_attr=base + "_fc.b")
        gru_h = nn.dynamic_gru(
            input=fc0,
            size=self.hidden_size,
            is_reverse=False,
            param_attr=base,
            bias_attr=base + ".b")
        return nn.sequence_pool(input=gru_h, pool_type='max')

D
dongdaxiang 已提交
77

D
dongdaxiang 已提交
78 79
class SimpleEncoderFactory(object):
    """A very simple encoder factory.

    Encoders are built with mostly default argument values; only the encoder
    type and its size are configurable.
    """

    def __init__(self):
        pass

    def create(self, enc_type, enc_hid_size):
        """Create an encoder.

        Args:
            enc_type: ``"bow"``, ``"cnn"`` or ``"gru"``.
            enc_hid_size: number of filters for ``"cnn"``, hidden size for
                ``"gru"``; ignored for ``"bow"``.

        Returns:
            A ``BowEncoder``, ``CNNEncoder`` or ``GrnnEncoder`` instance.

        Raises:
            ValueError: if ``enc_type`` is not one of the supported types.
        """
        if enc_type == "bow":
            return BowEncoder()
        if enc_type == "cnn":
            return CNNEncoder(ksize=enc_hid_size)
        if enc_type == "gru":
            return GrnnEncoder(hidden_size=enc_hid_size)
        # Fail fast: returning None would only surface later as an
        # AttributeError when the caller invokes .forward().
        raise ValueError(
            "unknown encoder type %r; expected 'bow', 'cnn' or 'gru'" % (enc_type,))

99

D
dongdaxiang 已提交
100 101
class MultiviewSimnet(object):
    """Multi-view SimNet: a pairwise-ranking query/title matching network.

    Each side (query, title) has one input slot per encoder ("view"). Every
    slot is looked up in one shared embedding table, encoded by its view's
    encoder, the views are concatenated and projected by a fully connected
    layer, and relevance is the cosine similarity of the projections.
    """

    def __init__(self, embedding_size, embedding_dim, hidden_size):
        """
        Args:
            embedding_size: number of rows of the embedding table.
            embedding_dim: embedding width.
            hidden_size: output width of the final projection layers.
        """
        self.embedding_size = embedding_size
        self.embedding_dim = embedding_dim
        self.emb_shape = [self.embedding_size, self.embedding_dim]
        self.hidden_size = hidden_size
        # Margin of the pairwise hinge loss built in train_net().
        self.margin = 0.1

    def set_query_encoder(self, encoders):
        """Set the query encoders, one per query input slot."""
        self.query_encoders = encoders

    def set_title_encoder(self, encoders):
        """Set the title encoders, one per title input slot."""
        self.title_encoders = encoders

    def get_correct(self, x, y):
        """Return the count (float32, summed over all elements) of positions where ``x < y``."""
        less = tensor.cast(cf.less_than(x, y), dtype='float32')
        correct = nn.reduce_sum(less)
        return correct

    @staticmethod
    def _data_slots(prefix, count):
        """Declare ``count`` int64 LoD id inputs named ``<prefix>0``, ``<prefix>1``, ..."""
        return [
            io.data(
                name="%s%d" % (prefix, i), shape=[1], lod_level=1, dtype='int64')
            for i in range(count)
        ]

    def _embed(self, slots):
        """Look up every slot in the shared ``emb.w`` embedding table."""
        return [
            nn.embedding(input=slot, size=self.emb_shape, param_attr="emb.w")
            for slot in slots
        ]

    @staticmethod
    def _encode(encoders, embs):
        """Apply the i-th encoder to the i-th embedded slot."""
        return [encoder.forward(emb) for encoder, emb in zip(encoders, embs)]

    def _project(self, concat, name):
        """Fully connected projection to ``hidden_size`` with named weight and bias."""
        # Naming the bias as well as the weight makes layers built with the
        # same ``name`` share both; an unnamed bias is created fresh per call.
        return nn.fc(
            concat,
            size=self.hidden_size,
            param_attr=name + '.w',
            bias_attr=name + '.b')

    def train_net(self):
        """Build the training network.

        Returns:
            A tuple ``(inputs, avg_cost, correct)``: the data layers (query
            slots, then positive-title slots, then negative-title slots), the
            mean pairwise hinge loss, and the count of pairs whose negative
            cosine is below the positive one.
        """
        # input fields for query, pos_title, neg_title
        q_slots = self._data_slots("q", len(self.query_encoders))
        pt_slots = self._data_slots("pt", len(self.title_encoders))
        nt_slots = self._data_slots("nt", len(self.title_encoders))

        # lookup embedding for each slot (one table shared by all slots)
        q_embs = self._embed(q_slots)
        pt_embs = self._embed(pt_slots)
        nt_embs = self._embed(nt_slots)

        # encode each embedding field with its encoder; positive and
        # negative titles go through the same encoder objects
        q_encodes = self._encode(self.query_encoders, q_embs)
        pt_encodes = self._encode(self.title_encoders, pt_embs)
        nt_encodes = self._encode(self.title_encoders, nt_embs)

        # concat multi view for query, pos_title, neg_title
        # NOTE(review): nn.concat defaults to axis=0, i.e. views are stacked
        # along the first dimension rather than joined feature-wise — confirm
        # that is intended.
        q_concat = nn.concat(q_encodes)
        pt_concat = nn.concat(pt_encodes)
        nt_concat = nn.concat(nt_encodes)

        # projection of hidden layer; both title projections share 't_fc'
        q_hid = self._project(q_concat, 'q_fc')
        pt_hid = self._project(pt_concat, 't_fc')
        nt_hid = self._project(nt_concat, 't_fc')

        # cosine of hidden layers
        cos_pos = nn.cos_sim(q_hid, pt_hid)
        cos_neg = nn.cos_sim(q_hid, nt_hid)

        # pairwise hinge loss: max(0, margin - cos_pos + cos_neg)
        loss_part1 = nn.elementwise_sub(
            tensor.fill_constant_batch_size_like(
                input=cos_pos,
                shape=[-1, 1],
                value=self.margin,
                dtype='float32'),
            cos_pos)

        loss_part2 = nn.elementwise_add(loss_part1, cos_neg)

        loss_part3 = nn.elementwise_max(
            tensor.fill_constant_batch_size_like(
                input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'),
            loss_part2)

        avg_cost = nn.mean(loss_part3)
        correct = self.get_correct(cos_neg, cos_pos)

        return q_slots + pt_slots + nt_slots, avg_cost, correct

    def pred_net(self, query_fields, pos_title_fields, neg_title_fields):
        """Build the inference network scoring query/title relevance.

        Declares the same ``q*``/``pt*`` inputs and uses the same explicitly
        named parameters as train_net().

        Args:
            query_fields: unused; kept for interface compatibility.
            pos_title_fields: unused; kept for interface compatibility.
            neg_title_fields: unused; kept for interface compatibility.

        Returns:
            The cosine similarity between the query and title projections.
        """
        q_slots = self._data_slots("q", len(self.query_encoders))
        pt_slots = self._data_slots("pt", len(self.title_encoders))

        # lookup embedding for each slot
        q_embs = self._embed(q_slots)
        pt_embs = self._embed(pt_slots)

        # encode each embedding field with its encoder
        q_encodes = self._encode(self.query_encoders, q_embs)
        pt_encodes = self._encode(self.title_encoders, pt_embs)

        # concat multi view for query and title
        q_concat = nn.concat(q_encodes)
        pt_concat = nn.concat(pt_encodes)

        # projection of hidden layer (same parameter names as train_net)
        q_hid = self._project(q_concat, 'q_fc')
        pt_hid = self._project(pt_concat, 't_fc')

        # cosine of hidden layers
        cos = nn.cos_sim(q_hid, pt_hid)
        return cos