# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layers import BeamSearchDecoder

from paddle.text import RNNCell, RNN, DynamicDecode


class ConvBNPool(paddle.nn.Layer):
    """Two (conv -> batch norm) stages, optionally followed by max pooling.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both convolutions.
        act: activation applied inside each BatchNorm layer.
        is_test: accepted for API compatibility; not used by this layer.
        pool: if true, append a 2x2, stride-2, ceil-mode max-pool.
        use_cudnn: forwarded to the convolution and pooling layers.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 act="relu",
                 is_test=False,
                 pool=True,
                 use_cudnn=True):
        super(ConvBNPool, self).__init__()

        filter_size = 3
        # "Same" padding for an odd kernel keeps the spatial size unchanged;
        # derived from filter_size so the two can never drift apart.
        padding = filter_size // 2

        # He-normal init: std = sqrt(2 / fan_in) with fan_in = k * k * C_in.
        std = (2.0 / (filter_size**2 * in_ch))**0.5
        param_0 = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, std))

        # conv1 consumes out_ch channels, so its fan_in uses out_ch.
        std = (2.0 / (filter_size**2 * out_ch))**0.5
        param_1 = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, std))

        self.conv0 = fluid.dygraph.Conv2D(
            in_ch,
            out_ch,
            filter_size,
            padding=padding,
            param_attr=param_0,
            bias_attr=False,
            act=None,
            use_cudnn=use_cudnn)
        self.bn0 = fluid.dygraph.BatchNorm(out_ch, act=act)
        self.conv1 = fluid.dygraph.Conv2D(
            out_ch,
            out_ch,
            filter_size=filter_size,
            padding=padding,
            param_attr=param_1,
            bias_attr=False,
            act=None,
            use_cudnn=use_cudnn)
        self.bn1 = fluid.dygraph.BatchNorm(out_ch, act=act)

        # `self.pool` is either the pooling sub-layer or None. It is no longer
        # overloaded as both a bool flag and a layer.
        self.pool = fluid.dygraph.Pool2D(
            pool_size=2,
            pool_type='max',
            pool_stride=2,
            use_cudnn=use_cudnn,
            ceil_mode=True) if pool else None

    def forward(self, inputs):
        """Apply conv0 -> bn0 -> conv1 -> bn1 and, if enabled, max-pool."""
        out = self.bn0(self.conv0(inputs))
        out = self.bn1(self.conv1(out))
        if self.pool is not None:
            out = self.pool(out)
        return out


class CNN(paddle.nn.Layer):
    """Convolutional backbone made of four ConvBNPool stages.

    The stages output 16, 32, 64 and 128 channels. The first three use the
    ConvBNPool default (pooling enabled); the last is built with
    ``pool=False``.
    """

    def __init__(self, in_ch=1, is_test=False):
        super(CNN, self).__init__()
        # `is_test` is accepted for API compatibility but is not forwarded.
        self.conv_bn1 = ConvBNPool(in_ch, 16)
        self.conv_bn2 = ConvBNPool(16, 32)
        self.conv_bn3 = ConvBNPool(32, 64)
        self.conv_bn4 = ConvBNPool(64, 128, pool=False)

    def forward(self, inputs):
        """Run `inputs` through the four stages in order."""
        stages = (self.conv_bn1, self.conv_bn2, self.conv_bn3, self.conv_bn4)
        features = inputs
        for stage in stages:
            features = stage(features)
        return features


class GRUCell(RNNCell):
    """Single GRU step: a bias-free Linear projection feeding fluid's GRUUnit.

    The step input is projected to ``3 * hidden_size`` features, the width
    handed to ``fluid.dygraph.GRUUnit``. That unit combines them with the
    previous hidden state.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 param_attr=None,
                 bias_attr=None,
                 gate_activation='sigmoid',
                 candidate_activation='tanh',
                 origin_mode=False):
        super(GRUCell, self).__init__()
        self.hidden_size = hidden_size
        gate_width = hidden_size * 3

        # Bias lives in the GRUUnit, so the projection itself has none.
        self.fc_layer = fluid.dygraph.Linear(
            input_size, gate_width, param_attr=param_attr, bias_attr=False)

        self.gru_unit = fluid.dygraph.GRUUnit(
            gate_width,
            param_attr=param_attr,
            bias_attr=bias_attr,
            activation=candidate_activation,
            gate_activation=gate_activation,
            origin_mode=origin_mode)

    def forward(self, inputs, states):
        """Advance one step.

        Follows the RNNCell contract ``(step_outputs, new_states)``. For a GRU
        both are the new hidden state.
        """
        projected = self.fc_layer(inputs)
        new_hidden, _, _ = self.gru_unit(projected, states)
        return new_hidden, new_hidden

    @property
    def state_shape(self):
        """Shape of the recurrent state reported to the RNN wrapper."""
        return [self.hidden_size]


class Encoder(paddle.nn.Layer):
    """CNN backbone followed by a bidirectional GRU over feature-map columns.

    ``forward`` returns ``(gru_bwd, encoded_vector, encoded_proj)``:
    - the backward GRU's outputs;
    - the forward and backward outputs concatenated on the last axis;
    - that concatenation projected to ``decoder_size`` features.
    """

    def __init__(
            self,
            in_channel=1,
            rnn_hidden_size=200,
            decoder_size=128,
            is_test=False, ):
        super(Encoder, self).__init__()
        self.rnn_hidden_size = rnn_hidden_size

        self.backbone = CNN(in_ch=in_channel, is_test=is_test)

        weight_attr = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, 0.02))
        gru_bias_attr = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)

        def make_gru(reverse):
            # Both directions share identical hyper-parameters; only the
            # scan direction differs.
            # NOTE(review): 128 * 6 assumes the backbone emits 128 channels
            # with a feature-map height of 6 (e.g. 48-pixel-high inputs after
            # three 2x pools) -- confirm against the data pipeline.
            cell = GRUCell(
                input_size=128 * 6,
                hidden_size=rnn_hidden_size,
                param_attr=weight_attr,
                bias_attr=gru_bias_attr,
                candidate_activation='relu')
            return RNN(cell=cell, is_reverse=reverse, time_major=False)

        self.gru_fwd = make_gru(False)
        self.gru_bwd = make_gru(True)
        self.encoded_proj_fc = fluid.dygraph.Linear(
            rnn_hidden_size * 2, decoder_size, bias_attr=False)

    def forward(self, inputs):
        """Encode an image batch into per-column sequence features."""
        features = self.backbone(inputs)
        # (N, C, H, W) -> (N, W, C, H): the width axis becomes the time axis.
        features = fluid.layers.transpose(features, perm=[0, 3, 1, 2])

        _, _, channels, height = features.shape
        # Flatten each column's C x H values into one vector per time step;
        # 0 copies the batch dimension from the input.
        seq_feature = fluid.layers.reshape(features,
                                           [0, -1, channels * height])

        fwd_out, _ = self.gru_fwd(seq_feature)
        bwd_out, _ = self.gru_bwd(seq_feature)

        encoded_vector = fluid.layers.concat(input=[fwd_out, bwd_out], axis=2)
        encoded_proj = self.encoded_proj_fc(encoded_vector)
        return bwd_out, encoded_vector, encoded_proj


class Attention(paddle.nn.Layer):
    """
    Additive (Bahdanau-style) attention, from
    Neural Machine Translation by Jointly Learning to Align and Translate.
    https://arxiv.org/abs/1409.0473
    """

    def __init__(self, decoder_size):
        super(Attention, self).__init__()
        self.fc1 = fluid.dygraph.Linear(
            decoder_size, decoder_size, bias_attr=False)
        self.fc2 = fluid.dygraph.Linear(decoder_size, 1, bias_attr=False)

    def forward(self, encoder_vec, encoder_proj, decoder_state):
        """Return the attention-weighted sum of `encoder_vec` over axis 1.

        Scores come from a one-hidden-layer alignment MLP:
        ``fc2(tanh(encoder_proj + fc1(decoder_state)))``. The last axis of the
        squeezed scores is normalised with softmax.
        """
        # Project the decoder state and add a time axis so it broadcasts
        # against every encoder step.
        query = fluid.layers.unsqueeze(self.fc1(decoder_state), [1])
        energy = fluid.layers.tanh(
            fluid.layers.elementwise_add(encoder_proj, query))

        scores = fluid.layers.squeeze(self.fc2(energy), [2])
        weights = fluid.layers.softmax(scores)

        weighted = fluid.layers.elementwise_mul(
            x=encoder_vec, y=weights, axis=0)
        return fluid.layers.reduce_sum(weighted, dim=1)


class DecoderCell(RNNCell):
    """One step of the attention decoder.

    Attends over the encoder outputs using the previous hidden state. The
    resulting context vector is concatenated with the current (embedded)
    input token, and the result advances a GRU.

    Args:
        encoder_size: hidden size of each encoder GRU direction. The context
            vector is expected to be ``2 * encoder_size`` wide (forward and
            backward outputs).
        decoder_size: hidden size of the decoder GRU.
        emb_dim: width of the embedded input token. Defaults to
            ``decoder_size``, which is what this cell implicitly assumed
            before the argument existed.
    """

    def __init__(self, encoder_size=200, decoder_size=128, emb_dim=None):
        super(DecoderCell, self).__init__()
        if emb_dim is None:
            emb_dim = decoder_size
        self.attention = Attention(decoder_size)
        # GRU input = [embedded token, context vector].
        self.gru_cell = GRUCell(
            input_size=encoder_size * 2 + emb_dim, hidden_size=decoder_size)

    def forward(self, current_word, states, encoder_vec, encoder_proj):
        """Return ``(hidden, hidden)``: the new state is also the output."""
        context = self.attention(encoder_vec, encoder_proj, states)
        decoder_inputs = fluid.layers.concat([current_word, context], axis=1)
        hidden, _ = self.gru_cell(decoder_inputs, states)
        return hidden, hidden


class Decoder(paddle.nn.Layer):
    """Attention decoder over a full target sequence, plus a softmax classifier.

    Args:
        num_classes: number of output classes. The classifier emits
            ``num_classes + 2`` probabilities; the two extra slots are
            presumably the start/end ids -- confirm against label encoding.
        emb_dim: width of the embedded target tokens fed to the cell.
        encoder_size: hidden size of each encoder GRU direction.
        decoder_size: hidden size of the decoder GRU.
    """

    def __init__(self, num_classes, emb_dim, encoder_size, decoder_size):
        super(Decoder, self).__init__()
        # Previously emb_dim was ignored, so the GRU input width was wrong
        # whenever emb_dim != decoder_size.
        self.decoder_attention = RNN(
            DecoderCell(encoder_size, decoder_size, emb_dim=emb_dim))
        self.fc = fluid.dygraph.Linear(
            decoder_size, num_classes + 2, act='softmax')

    def forward(self, target, initial_states, encoder_vec, encoder_proj):
        """Decode the embedded `target` sequence and return class probabilities."""
        out, _ = self.decoder_attention(
            target,
            initial_states=initial_states,
            encoder_vec=encoder_vec,
            encoder_proj=encoder_proj)
        return self.fc(out)


class Seq2SeqAttModel(paddle.nn.Layer):
    """Attention-based sequence-to-sequence recognizer (training graph).

    Args:
        in_channle: number of input image channels (misspelled name kept for
            backward compatibility).
        encoder_size: hidden size of each encoder GRU direction.
        decoder_size: hidden size of the decoder GRU.
        emb_dim: width of the target-token embeddings.
        num_classes: number of output classes. The embedding table and
            classifier both use ``num_classes + 2`` ids. Required despite
            the ``None`` default.

    Raises:
        TypeError: if ``num_classes`` is not provided.
    """

    def __init__(
            self,
            in_channle=1,
            encoder_size=200,
            decoder_size=128,
            emb_dim=128,
            num_classes=None, ):
        if num_classes is None:
            # Fail early with a clear message instead of the opaque
            # "NoneType + int" error from the size computations below.
            raise TypeError(
                "Seq2SeqAttModel requires the `num_classes` argument")
        super(Seq2SeqAttModel, self).__init__()
        self.encoder = Encoder(in_channle, encoder_size, decoder_size)
        # Maps the backward-GRU summary to the decoder's initial state.
        self.fc = fluid.dygraph.Linear(
            input_dim=encoder_size,
            output_dim=decoder_size,
            bias_attr=False,
            act='relu')
        self.embedding = fluid.dygraph.Embedding(
            [num_classes + 2, emb_dim], dtype='float32')
        self.decoder = Decoder(num_classes, emb_dim, encoder_size,
                               decoder_size)

    def forward(self, inputs, target):
        """Teacher-forced forward pass; returns per-step class probabilities."""
        gru_backward, encoded_vector, encoded_proj = self.encoder(inputs)
        # The encoder's backward GRU runs in reverse, so time step 0 is its
        # last step. That step initialises the decoder state.
        decoder_boot = self.fc(gru_backward[:, 0])
        trg_embedding = self.embedding(target)
        prediction = self.decoder(trg_embedding, decoder_boot, encoded_vector,
                                  encoded_proj)
        return prediction


class Seq2SeqAttInferModel(Seq2SeqAttModel):
    """Inference variant of Seq2SeqAttModel that decodes with beam search.

    The decoder cell, embedding and classifier are reused from the parent
    model. They are wrapped in a ``BeamSearchDecoder`` and run by
    ``DynamicDecode`` for at most ``max_out_len`` steps.
    """

    def __init__(
            self,
            in_channle=1,
            encoder_size=200,
            decoder_size=128,
            emb_dim=128,
            num_classes=None,
            beam_size=0,
            bos_id=0,
            eos_id=1,
            max_out_len=20, ):
        super(Seq2SeqAttInferModel, self).__init__(
            in_channle, encoder_size, decoder_size, emb_dim, num_classes)
        self.beam_size = beam_size

        beam_decoder = BeamSearchDecoder(
            self.decoder.decoder_attention.cell,
            start_token=bos_id,
            end_token=eos_id,
            beam_size=beam_size,
            embedding_fn=self.embedding,
            output_fn=self.decoder.fc)
        self.infer_decoder = DynamicDecode(
            beam_decoder, max_step_num=max_out_len, is_test=True)

    def forward(self, inputs, *args):
        """Beam-search decode `inputs`. Extra positional args are ignored."""
        gru_backward, encoded_vector, encoded_proj = self.encoder(inputs)
        decoder_boot = self.fc(gru_backward[:, 0])

        if self.beam_size:
            # The attention tensors reach the cell as plain kwargs, so their
            # batch dimension is tiled with beam_size here.
            tile = BeamSearchDecoder.tile_beam_merge_with_batch
            encoded_vector = tile(encoded_vector, self.beam_size)
            encoded_proj = tile(encoded_proj, self.beam_size)

        predictions, _ = self.infer_decoder(
            inits=decoder_boot,
            encoder_vec=encoded_vector,
            encoder_proj=encoded_proj)
        return predictions


class WeightCrossEntropy(paddle.nn.Layer):
    """Masked cross-entropy loss, summed over every element."""

    def __init__(self):
        super(WeightCrossEntropy, self).__init__()

    def forward(self, predict, label, mask):
        """Return ``reduce_sum(cross_entropy(predict, label) * mask)``.

        fluid's ``cross_entropy`` expects probabilities rather than logits.
        The mask is multiplied in with ``axis=0`` broadcasting.
        """
        per_token = layers.cross_entropy(predict, label=label)
        masked = layers.elementwise_mul(per_token, mask, axis=0)
        return layers.reduce_sum(masked)