#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import math
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers


def network(items_num, hidden_size, step, bs):
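    """Build a session-based recommendation graph network (SR-GNN style).

    Args:
        items_num: item vocabulary size (id 0 appears to be reserved for
            padding; see the all_vocab candidate set below).
        hidden_size: width h of the item/node embeddings.
        step: number of gated graph propagation steps.
        bs: batch size.

    Returns:
        loss, top-20 accuracy, the DataLoader, and the list of feed variables.
    """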
    stdv = 1.0 / math.sqrt(hidden_size)

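    # Batch inputs. uniq_max is the number of unique items in a session
    # graph and seq_max is the padded click-sequence length; seq_index and
    # last_index hold (batch, node) index pairs consumed by gather_nd below.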
    items = fluid.data(
        name="items",
        shape=[bs, -1],
        dtype="int64") #[batch_size, uniq_max]
    seq_index = fluid.data(
        name="seq_index",
        shape=[bs, -1, 2],
        dtype="int32") #[batch_size, seq_max, 2]
    last_index = fluid.data(
        name="last_index",
        shape=[bs, 2],
        dtype="int32") #[batch_size, 2]
    adj_in = fluid.data(
        name="adj_in",
        shape=[bs, -1, -1],
        dtype="float32") #[batch_size, seq_max, seq_max]
    adj_out = fluid.data(
        name="adj_out",
        shape=[bs, -1, -1],
        dtype="float32") #[batch_size, seq_max, seq_max]
    mask = fluid.data(
        name="mask",
        shape=[bs, -1, 1],
        dtype="float32") #[batch_size, seq_max, 1]
    label = fluid.data(
        name="label",
        shape=[bs, 1],
        dtype="int64") #[batch_size, 1]

    datas = [items, seq_index, last_index, adj_in, adj_out, mask, label]
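    # One DataLoader feeds all seven input tensors; the variable keeps its
    # legacy py_reader name.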
    py_reader = fluid.io.DataLoader.from_generator(
        capacity=256, feed_list=datas, iterable=False)
    feed_datas = datas

    items_emb = fluid.embedding(
        input=items,
        param_attr=fluid.ParamAttr(
            name="emb",
            initializer=fluid.initializer.Uniform(
                low=-stdv, high=stdv)),
        size=[items_num, hidden_size])  #[batch_size, uniq_max, h]

    pre_state = items_emb
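    # Gated graph propagation: each step transforms the node states once per
    # edge direction, aggregates them with the in/out adjacency matrices, and
    # updates the states through a GRU cell.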
    for i in range(step):
        pre_state = layers.reshape(x=pre_state, shape=[bs, -1, hidden_size])
        state_in = layers.fc(
            input=pre_state,
            name="state_in",
            size=hidden_size,
            act=None,
            num_flatten_dims=2,
            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
                low=-stdv, high=stdv)),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
                low=-stdv, high=stdv)))  #[batch_size, uniq_max, h]
        state_out = layers.fc(
            input=pre_state,
            name="state_out",
            size=hidden_size,
            act=None,
            num_flatten_dims=2,
            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
                low=-stdv, high=stdv)),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
                low=-stdv, high=stdv)))  #[batch_size, uniq_max, h]

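        # Aggregate the transformed neighbor states along incoming and
        # outgoing edges.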
        state_adj_in = layers.matmul(adj_in, state_in)  #[batch_size, uniq_max, h]
        state_adj_out = layers.matmul(adj_out, state_out)   #[batch_size, uniq_max, h]

        gru_input = layers.concat([state_adj_in, state_adj_out], axis=2)

        gru_input = layers.reshape(x=gru_input, shape=[-1, hidden_size * 2])
        gru_fc = layers.fc(
            input=gru_input,
            name="gru_fc",
            size=3 * hidden_size,
            bias_attr=False)
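        # gru_unit takes the 3*h input projection from gru_fc (update, reset,
        # and candidate gates) and owns the hidden-to-hidden weights; node
        # states are flattened to [batch_size * uniq_max, h] for the update.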
        pre_state, _, _ = fluid.layers.gru_unit(
            input=gru_fc,
            hidden=layers.reshape(x=pre_state, shape=[-1, hidden_size]),
            size=3 * hidden_size)

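    # Recover per-position hidden states for each click sequence (seq) and
    # the hidden state of each session's last click (last).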
    final_state = layers.reshape(pre_state, shape=[bs, -1, hidden_size])
    seq = layers.gather_nd(final_state, seq_index)
    last = layers.gather_nd(final_state, last_index)

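    # Soft attention over sequence positions conditioned on the last click:
    # weight_i = W * sigmoid(W1*h_i + W2*h_last + b), masked below.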
    seq_fc = layers.fc(
        input=seq,
        name="seq_fc",
        size=hidden_size,
        bias_attr=False,
        act=None,
        num_flatten_dims=2,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Uniform(
            low=-stdv, high=stdv)))  #[batch_size, seq_max, h]
    last_fc = layers.fc(
        input=last,
        name="last_fc",
        size=hidden_size,
        bias_attr=False,
        act=None,
        num_flatten_dims=1,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Uniform(
            low=-stdv, high=stdv)))  #[batch_size, h]

    seq_fc_t = layers.transpose(
        seq_fc, perm=[1, 0, 2])  #[seq_max, batch_size, h]
    add = layers.elementwise_add(
        seq_fc_t, last_fc)  #[seq_max, batch_size, h]
    b = layers.create_parameter(
        shape=[hidden_size],
        dtype='float32',
        default_initializer=fluid.initializer.Constant(value=0.0))  #[h]
    add = layers.elementwise_add(add, b)  #[seq_max, batch_size, h]

    add_sigmoid = layers.sigmoid(add)  #[seq_max, batch_size, h]
    add_sigmoid = layers.transpose(
        add_sigmoid, perm=[1, 0, 2])  #[batch_size, seq_max, h]

    weight = layers.fc(
        input=add_sigmoid,
        name="weight_fc",
        size=1,
        act=None,
        num_flatten_dims=2,
        bias_attr=False,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Uniform(
                low=-stdv, high=stdv)))  #[batch_size, seq_max, 1]
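    # Zero the attention scores of padded positions before pooling.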
    weight *= mask
    weight_mask = layers.elementwise_mul(seq, weight, axis=0) #[batch_size, seq_max, h]
    global_attention = layers.reduce_sum(weight_mask, dim=1) #[batch_size, h]

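    # Hybrid session embedding: concatenate the attention-pooled global
    # preference with the last click (local preference), then project to h.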
    final_attention = layers.concat(
        [global_attention, last], axis=1)  #[batch_size, 2*h]
    final_attention_fc = layers.fc(
        input=final_attention,
        name="final_attention_fc",
        size=hidden_size,
        bias_attr=False,
        act=None,
        param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
            low=-stdv, high=stdv)))  #[batch_size, h]

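    # Score sessions against every candidate item; candidate embeddings share
    # weights with the input embedding via name="emb". all_vocab is a
    # persistable placeholder that the caller is expected to fill with the
    # item ids 1..items_num-1 before running the program.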
    all_vocab = layers.create_global_var(
        shape=[items_num - 1],
        value=0,
        dtype="int64",
        persistable=True,
        name="all_vocab")

    all_emb = fluid.embedding(
        input=all_vocab,
        param_attr=fluid.ParamAttr(
            name="emb",
            initializer=fluid.initializer.Uniform(
                low=-stdv, high=stdv)),
        size=[items_num, hidden_size])  #[all_vocab, h]

    logits = layers.matmul(
        x=final_attention_fc, y=all_emb,
        transpose_y=True)  #[batch_size, all_vocab]
    softmax = layers.softmax_with_cross_entropy(
        logits=logits, label=label)  #[batch_size, 1]
    loss = layers.reduce_mean(softmax)  # [1]
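    # k=20: counts a hit when the true next item is among the 20
    # highest-scoring candidates.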
    acc = layers.accuracy(input=logits, label=label, k=20)
    return loss, acc, py_reader, feed_datas
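

# Minimal usage sketch (values are illustrative, not taken from this file):
#     loss, acc, loader, feed_list = network(
#         items_num=1000, hidden_size=100, step=1, bs=100)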