/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/attention_lstm_op.h"

#include <string>

#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"

namespace paddle {
namespace operators {

void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "AttentionLstm");
  OP_INOUT_CHECK(
      ctx->HasInput("LSTMWeight"), "Input", "LSTMWeight", "AttentionLstm");
  OP_INOUT_CHECK(
      ctx->HasInput("LSTMBias"), "Input", "LSTMBias", "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasInput("AttentionWeight"),
                 "Input",
                 "AttentionWeight",
                 "AttentionLstm");

  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasOutput("AttentionedX"),
                 "Output",
                 "AttentionedX",
                 "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasOutput("AttentionFCOut"),
                 "Output",
                 "AttentionFCOut",
                 "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasOutput("LSTMX"), "Output", "LSTMX", "AttentionLstm");
  OP_INOUT_CHECK(
      ctx->HasOutput("LSTMOUT"), "Output", "LSTMOUT", "AttentionLstm");

  auto x_dims = ctx->GetInputDim("X");
  const int M = x_dims[1];
  PADDLE_ENFORCE_EQ(x_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "Expected input(X)'s dimension is 2. But received %d.",
                        x_dims.size()));

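  // LSTMWeight is laid out as (D + M) x 4D, so the gate size D is one quarter
  // of its second dimension; M is the input frame size taken from X.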
  auto w_dims = ctx->GetInputDim("LSTMWeight");
  const int D = w_dims[1] / 4;
  PADDLE_ENFORCE_EQ(
      w_dims.size(),
      2,
      platform::errors::InvalidArgument(
          "Expected input(LSTMWeight)'s dimension is 2.But received %d.",
          w_dims.size()));
  PADDLE_ENFORCE_EQ(
      w_dims[0],
      D + M,
      platform::errors::InvalidArgument(
          "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D));

  auto b_dims = ctx->GetInputDim("LSTMBias");
  PADDLE_ENFORCE_EQ(
      b_dims.size(),
      2,
      platform::errors::InvalidArgument("Input(LSTMBias)'s rank must be 2."));
  PADDLE_ENFORCE_EQ(b_dims[0],
                    1,
                    platform::errors::InvalidArgument(
                        "LSTMBias dims should be 1 x %d.", 4 * D));
  PADDLE_ENFORCE_EQ(b_dims[1],
                    4 * D,
                    platform::errors::InvalidArgument(
                        "LSTMBias dims should be 1 x %d.", 4 * D));

  auto c_dims = ctx->GetInputDim("C0");
  PADDLE_ENFORCE_EQ(
      c_dims.size(),
      2,
      platform::errors::InvalidArgument("Input(C0)'s rank must be 2."));
  if (ctx->IsRuntime()) {
    PADDLE_ENFORCE_EQ(
        c_dims[1],
        D,
        platform::errors::InvalidArgument("C0 dims should be N x %d.", D));
  }

  if (ctx->HasInput("H0")) {
    auto h_dims = ctx->GetInputDim("H0");
    PADDLE_ENFORCE_EQ(
        h_dims.size(),
        2UL,
        platform::errors::InvalidArgument(
            "Expected input(H0)'s dimension is 2. But received %d.",
            h_dims.size()));
    if (ctx->IsRuntime() ||
        (phi::product(c_dims) > 0 && phi::product(h_dims) > 0)) {
      PADDLE_ENFORCE_EQ(h_dims,
                        c_dims,
                        platform::errors::InvalidArgument(
                            "The dimension of Input(H0) and Input(C0) "
                            "should be the same."));
    }
  }

  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
  PADDLE_ENFORCE_EQ(atten_w_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "Input(AttentionWeight)'s rank must be 2."));
  PADDLE_ENFORCE_EQ(atten_w_dims[0],
                    M + D,
                    platform::errors::InvalidArgument(
                        "Expected `AttentionWeight` shape is [(%d + %d), 1]. "
                        "But received shape = [%d, 1], shape[0] is not %d.",
                        M,
                        D,
                        atten_w_dims[0],
                        M + D));
  PADDLE_ENFORCE_EQ(atten_w_dims[1],
                    1,
                    platform::errors::InvalidArgument(
                        "AttentionWeight shapes must be (%d + %d) * 1.", M, D));

  if (ctx->HasInput("AttentionBias")) {
    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
    PADDLE_ENFORCE_EQ(atten_b_dims.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "Input(AttentionBias)'s rank must be 2."));
    PADDLE_ENFORCE_EQ(atten_b_dims[0],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionBias shapes must be 1 * 1."));
    PADDLE_ENFORCE_EQ(atten_b_dims[1],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionBias shapes must be 1 * 1."));
  }

  if (ctx->HasInput("AttentionScalar")) {
    auto dims = ctx->GetInputDim("AttentionScalar");
    PADDLE_ENFORCE_EQ(dims.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "Input(AttentionScalar)'s rank must be 2."));
    PADDLE_ENFORCE_EQ(dims[0],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionScalar shapes must be 1 * 1."));
    PADDLE_ENFORCE_EQ(dims[1],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionScalar shapes must be 1 * 1."));
  }

  if (ctx->HasInput("AttentionScalarBias")) {
    auto dims = ctx->GetInputDim("AttentionScalarBias");
    OP_INOUT_CHECK(ctx->HasInput("AttentionScalar"),
                   "Input",
                   "AttentionScalar",
                   "AttentionLstm");
    PADDLE_ENFORCE_EQ(dims.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "Input(AttentionScalarBias)'s rank must be 2."));
    PADDLE_ENFORCE_EQ(dims[0],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionScalarBias shapes must be 1 * 1."));
    PADDLE_ENFORCE_EQ(dims[1],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionScalarBias shapes must be 1 * 1."));
  }

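  // All outputs except AttentionFCOut can be shaped here; Hidden and Cell are
  // (T x D), while the remaining outputs are intermediate buffers.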
  framework::DDim out_dims({x_dims[0], D});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
  ctx->SetOutputDim("LSTMX", {1, M});
  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
  // AttentionFCOut is reshaped to (max_seq_len, 1) at runtime.
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");
}

framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
}

void AttentionLSTMOpMaker::Make() {
  AddInput("X",
           "(LoDTensor) the input is a LodTensor, which support "
           "variable-time length input sequence. The underlying tensor in "
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
  AddInput("C0",
           "(Tensor) LSTM C0"
           "This is a tensor with shape (N x D), where N is the batch size, D "
           "is the gate size."
           "C0 is necessary because of attention.");
  AddInput("H0",
           "(Tensor, optional) LSTM H0"
           "This is a tensor with shape (N x D), where N is the "
           "batch size and D is the gate size.")
      .AsDispensable();
  AddInput("AttentionWeight",
           "(Tensor) the weights of attention fc. Always relu the fc result."
           "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
           "gate size of LSTM.");
  AddInput("AttentionBias",
           "(Tensor, optional) the bias of attention fc."
229 230 231 232 233 234 235 236 237 238
           "The shape is (1 x 1)")
      .AsDispensable();
  AddInput("AttentionScalar",
           "(Tensor, optional) the scalar on the result of attentioned fc. "
           "Always relu the Scalar."
           "The shape is (1 x 1)")
      .AsDispensable();
  AddInput("AttentionScalarBias",
           "(Tensor, optional) the scalar bias of attention fc."
           "The shape is (1 x 1)")
      .AsDispensable();
  AddInput("LSTMWeight",
           "(Tensor) the combined weight of LSTM"
           " - The shape is ((D+M) x 4D), where D is the hidden gate size, M "
           "is the dim size of x"
           " - Weight = {W_forget, W_input, W_output, W_cell}");
  AddInput("LSTMBias",
           "(Tensor) the combined bias of LSTM, shape (1x4D)."
           "Note: we should add the bias of hidden and context accorindg to "
           "the same gate: "
           "{B_forget, B_input, B_output, B_cell}");
  AddOutput("Hidden",
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("AttentionedX",
            "(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size.")
      .AsIntermediate();
  AddOutput("AttentionFCOut",
            "(Tensor) (max_seq_len, 1), compute at each step.")
      .AsIntermediate();
  AddOutput("LSTMX",
            "(Tensor) the input X of LSTM for each step."
            "Shape is (1 x M), where M is the x frame size")
      .AsIntermediate();
  AddOutput(
      "LSTMOUT",
      "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
      "Shape is (1 x 4D), where M is the x frame size")
      .AsIntermediate();
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid)"
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh)"
                       "The activation for cell output, `tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh)"
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddComment(R"DOC(
Attention Long Short-Term Memory (LSTM) Operator.

Attention part:
concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D))

tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu

fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu

dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M)

LSTM part:
use lstm_x_t as input and compute as standard LSTM.

)DOC");
}

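// A minimal shape walk-through of one attention step, assuming (for
// illustration only) M = 3, D = 2 and a sequence of length 5:
//   concat(x(5x3), expand(cell_{t-1}(1x2)))              => tmp(5x5)
//   tmp(5x5) * AttentionWeight(5x1), add bias, relu      => fcout(5x1)
//   fcout(5x1) * AttentionScalar, add scalar bias, relu  => fcout(5x1)
//   softmax(fcout), dot-mul and sum pool with x(5x3)     => lstm_x_t(1x3)
// lstm_x_t then drives one standard LSTM step with LSTMWeight((2+3) x 8).
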
// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
  if (bias) {
    phi::funcs::vec_add_bias<T, platform::avx>(n, *bias, x, y);
    phi::funcs::vec_relu<T, platform::avx>(n, y, y);
  } else {
    phi::funcs::vec_relu<T, platform::avx>(n, x, y);
  }
}

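// y = softmax(x): shift by the max for numerical stability, then exponentiate
// and normalize by the sum.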
template <typename T>
inline void vec_softmax(const int n, const T* x, T* y) {
  T scalar = x[0];
  // max
  for (int i = 1; i < n; ++i) {
    scalar = scalar < x[i] ? x[i] : scalar;
  }
  phi::funcs::vec_add_bias<T, platform::avx>(n, -scalar, x, y);  // sub
  phi::funcs::vec_exp<T>(n, y, y);                               // exp
  // sum
  scalar = T(0);
  for (int i = 0; i < n; ++i) {
    scalar += y[i];
  }
  phi::funcs::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
}

template <typename T>
class AttentionLSTMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using DeviceContext = phi::CPUContext;

    auto* x = ctx.Input<LoDTensor>("X");
    auto* h0 = ctx.Input<phi::DenseTensor>("H0");
    auto* c0 = ctx.Input<phi::DenseTensor>("C0");
    auto* atten_w = ctx.Input<phi::DenseTensor>("AttentionWeight");
    auto* atten_b = ctx.Input<phi::DenseTensor>("AttentionBias");
    auto* atten_scalar = ctx.Input<phi::DenseTensor>("AttentionScalar");
    auto* atten_scalar_bias =
        ctx.Input<phi::DenseTensor>("AttentionScalarBias");
    auto* lstm_w = ctx.Input<phi::DenseTensor>("LSTMWeight");
    auto* lstm_b = ctx.Input<phi::DenseTensor>("LSTMBias");

    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    auto* atted_x = ctx.Output<phi::DenseTensor>("AttentionedX");
    auto* fc_out = ctx.Output<phi::DenseTensor>("AttentionFCOut");
    auto* lstm_x = ctx.Output<phi::DenseTensor>("LSTMX");
    auto* lstm_out = ctx.Output<phi::DenseTensor>("LSTMOUT");

    // some shapes must be set here since InferShape cannot get the lod info
    auto x_lod = x->lod();
    const int N = x_lod[0].size() - 1;  // batch size
    auto x_dims = x->dims();            // T x M
    auto w_dims = lstm_w->dims();       // (D+M) x 4D
    const int total_T = x_dims[0];
    const int M = x_dims[1];      // x frame size
    const int D = w_dims[1] / 4;  // gate frame size
    const int D2 = D * 2;
    const int D3 = D * 3;
    const int D4 = w_dims[1];
    int max_seq_len = x_lod[0][1];
    for (int i = 1; i < N; ++i) {
      int len = x_lod[0][i + 1] - x_lod[0][i];
      max_seq_len = max_seq_len < len ? len : max_seq_len;
    }
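    // AttentionFCOut is shared across sequences, so it is resized below to
    // hold the longest one ({max_seq_len, 1}).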
    PADDLE_ENFORCE_EQ(
        x_lod.size(),
        1UL,
        platform::errors::InvalidArgument("Input(X)'s lod size must be 1."));
    PADDLE_ENFORCE_EQ(
        c0->dims()[0],
        N,
        platform::errors::InvalidArgument("C0 dims should be %d x %d.", N, D));
    fc_out->Resize({max_seq_len, 1});

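    // resolve the activation functors once; AVX-vectorized implementations are
    // used when the CPU supports them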
    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
    if (platform::MayIUse(platform::avx)) {
      phi::funcs::VecActivations<T, platform::avx> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    } else {
      phi::funcs::VecActivations<T, platform::isa_any> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    }

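    // raw data pointers; the optional inputs (H0, AttentionBias,
    // AttentionScalar, AttentionScalarBias) may be null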
    const T* x_data = x->data<T>();
    const T* h0_data = h0 ? h0->data<T>() : NULL;
    const T* c0_data = c0->data<T>();
    const T* lstm_w_data = lstm_w->data<T>();
    const T* lstm_b_data = lstm_b->data<T>();
    const T* atten_w_data = atten_w->data<T>();
    const T* atten_b_data = atten_b ? atten_b->data<T>() : NULL;
    const T* atten_scalar_data = atten_scalar ? atten_scalar->data<T>() : NULL;
    const T* atten_scalar_bias_data =
        atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;

    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
    T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
    T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
    T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
    T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());

    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);

    // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
    auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
    phi::funcs::FCFunctor<DeviceContext, T> fc;
    fc(dev_ctx,
       total_T,
       1,
       M,
       x_data,
       atten_w_data,
       atted_x_data,
       atten_b_data);
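    // atted_x now holds X * AttentionWeight[0:M] (+ AttentionBias) for the
    // whole batch; the cell-dependent part of the attention fc is added per
    // step inside the loop below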

    const T* cur_atten_x_data = atted_x_data;
    const T* cur_x_data = x_data;
    const T* prev_cell_data = NULL;
    const T* prev_hidden_data = NULL;
    T* cur_cell_out_data = cell_out_data;
    T* cur_hidden_out_data = hidden_out_data;
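    // process each sequence in the batch independently; LoD offsets give the
    // per-sequence lengths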
    for (int i = 0; i < N; ++i) {
      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
      prev_cell_data = c0_data + i * D;
      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
      for (int step = 0; step < seq_len; ++step) {
        /// 1. compute attention vector
        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
        T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
        // 1b. add cell bias and relu
        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
        // 1c. fc scalar
        if (atten_scalar_data) {
          blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
          bias_relu<T>(
              seq_len, fc_out_data, atten_scalar_bias_data, fc_out_data);
        }
        // 1d. softmax
        vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
        // mul x(seq_len*M) and sum pool
        fc(dev_ctx, 1, M, seq_len, fc_out_data, cur_x_data, lstm_x_data);

        /// 2. compute LSTM step
        // lstm weight : concat[forget , input , output , tilde]
        // shape : (D + M) x (4 * D)
        // fc inputX(1xM) * weightX(M*(4D))  => 1 x 4D
        blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data);
        if (prev_hidden_data) {
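          // accumulate prev_hidden(1xD) * weightH(Dx4D) into the gate
          // pre-activations; beta = 1 keeps the contribution from lstm_x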
          blas.GEMM(CblasNoTrans,
                    CblasNoTrans,
                    1,
                    D4,
                    D,
                    static_cast<T>(1),
                    prev_hidden_data,
                    D,
                    lstm_w_data,
                    D4,
                    static_cast<T>(1),
                    lstm_out_data,
                    D4);
        }
        // since the input is 1xM, we can simply add the bias
        blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);

        // gate act: sigmoid
        act_gate(D3, lstm_out_data, lstm_out_data);
        // candidate act: tanh
        act_cand(D, lstm_out_data + D3, lstm_out_data + D3);

        // a = forget * prev_cell
        blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);

        // b = input * tilde
        blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);

        // cell_out = a + b
        blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);

        // state act tanh(cell_out) * output_gate
        act_cell(D, cur_cell_out_data, lstm_out_data);
        blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);

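        // roll the recurrent state forward: the cell/hidden just written
        // become prev_cell/prev_hidden for the next time step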
        prev_hidden_data = cur_hidden_out_data;
        prev_cell_data = cur_cell_out_data;
        cur_cell_out_data = cur_cell_out_data + D;
        cur_hidden_out_data = cur_hidden_out_data + D;
      }
      cur_x_data = cur_x_data + seq_len * M;
      cur_atten_x_data = cur_atten_x_data + seq_len;
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(attention_lstm,
                  ops::AttentionLSTMOp,
                  ops::AttentionLSTMOpMaker);

REGISTER_OP_CPU_KERNEL(attention_lstm,
                       ops::AttentionLSTMKernel<float>,
                       ops::AttentionLSTMKernel<double>);