/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/attention_lstm_op.h"

#include <string>

#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"

namespace paddle {
namespace operators {

void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "AttentionLstm");
  OP_INOUT_CHECK(
      ctx->HasInput("LSTMWeight"), "Input", "LSTMWeight", "AttentionLstm");
  OP_INOUT_CHECK(
      ctx->HasInput("LSTMBias"), "Input", "LSTMBias", "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasInput("AttentionWeight"),
                 "Input",
                 "AttentionWeight",
                 "AttentionLstm");

  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasOutput("AttentionedX"),
                 "Output",
                 "AttentionedX",
                 "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasOutput("AttentionFCOut"),
                 "Output",
                 "AttentionFCOut",
                 "AttentionLstm");
  OP_INOUT_CHECK(ctx->HasOutput("LSTMX"), "Output", "LSTMX", "AttentionLstm");
  OP_INOUT_CHECK(
      ctx->HasOutput("LSTMOUT"), "Output", "LSTMOUT", "AttentionLstm");

  auto x_dims = ctx->GetInputDim("X");
  const int M = x_dims[1];
  PADDLE_ENFORCE_EQ(x_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "Expected input(X)'s dimension is 2. But received %d.",
                        x_dims.size()));

  auto w_dims = ctx->GetInputDim("LSTMWeight");
  const int D = w_dims[1] / 4;
  PADDLE_ENFORCE_EQ(
      w_dims.size(),
      2,
      platform::errors::InvalidArgument(
          "Expected input(LSTMWeight)'s dimension is 2.But received %d.",
          w_dims.size()));
  PADDLE_ENFORCE_EQ(
      w_dims[0],
      D + M,
      platform::errors::InvalidArgument(
          "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D));

  auto b_dims = ctx->GetInputDim("LSTMBias");
  PADDLE_ENFORCE_EQ(
      b_dims.size(),
      2,
      platform::errors::InvalidArgument("Input(LSTMBias)'s rank must be 2."));
  PADDLE_ENFORCE_EQ(b_dims[0],
                    1,
                    platform::errors::InvalidArgument(
                        "LSTMBias dims should be 1 x %d.", 4 * D));
  PADDLE_ENFORCE_EQ(b_dims[1],
                    4 * D,
                    platform::errors::InvalidArgument(
                        "LSTMBias dims should be 1 x %d.", 4 * D));

  auto c_dims = ctx->GetInputDim("C0");
  PADDLE_ENFORCE_EQ(
      c_dims.size(),
      2,
      platform::errors::InvalidArgument("Input(C0)'s rank must be 2."));
  if (ctx->IsRuntime()) {
    PADDLE_ENFORCE_EQ(
        c_dims[1],
        D,
        platform::errors::InvalidArgument("C0 dims should be N x %d.", D));
  }

  if (ctx->HasInput("H0")) {
    auto h_dims = ctx->GetInputDim("H0");
    PADDLE_ENFORCE_EQ(
        h_dims.size(),
        2UL,
        platform::errors::InvalidArgument(
            "Expected input(H0)'s dimension is 2. But received %d.",
            h_dims.size()));
    if (ctx->IsRuntime() ||
        (phi::product(c_dims) > 0 && phi::product(h_dims) > 0)) {
      PADDLE_ENFORCE_EQ(h_dims,
                        c_dims,
                        platform::errors::InvalidArgument(
                            "The dimension of Input(H0) and Input(C0) "
                            "should be the same."));
    }
  }

  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
  PADDLE_ENFORCE_EQ(atten_w_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "Input(AttentionWeight)'s rank must be 2."));
  PADDLE_ENFORCE_EQ(atten_w_dims[0],
                    M + D,
                    platform::errors::InvalidArgument(
                        "Expected `AttentionWeight` shape is [(%d + %d), 1]. "
                        "But received shape = [%d, 1], shape[0] is not %d.",
                        M,
                        D,
                        atten_w_dims[0],
                        M + D));
  PADDLE_ENFORCE_EQ(atten_w_dims[1],
                    1,
                    platform::errors::InvalidArgument(
                        "AttentionWeight shapes must be (%d + %d) * 1.", M, D));

  if (ctx->HasInput("AttentionBias")) {
    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
    PADDLE_ENFORCE_EQ(atten_b_dims.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "Input(AttentionBias)'s rank must be 2."));
    PADDLE_ENFORCE_EQ(atten_b_dims[0],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionBias shapes must be 1 * 1."));
    PADDLE_ENFORCE_EQ(atten_b_dims[1],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionBias shapes must be 1 * 1."));
  }

  if (ctx->HasInput("AttentionScalar")) {
    auto dims = ctx->GetInputDim("AttentionScalar");
    PADDLE_ENFORCE_EQ(dims.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "Input(AttentionScalar)'s rank must be 2."));
    PADDLE_ENFORCE_EQ(dims[0],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionScalar shapes must be 1 * 1."));
    PADDLE_ENFORCE_EQ(dims[1],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionScalar shapes must be 1 * 1."));
  }

  if (ctx->HasInput("AttentionScalarBias")) {
    auto dims = ctx->GetInputDim("AttentionScalarBias");
    OP_INOUT_CHECK(ctx->HasInput("AttentionScalar"),
                   "Input",
                   "AttentionScalar",
                   "AttentionLstm");
    PADDLE_ENFORCE_EQ(dims.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "Input(AttentionScalarBias)'s rank must be 2."));
    PADDLE_ENFORCE_EQ(dims[0],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionScalarBias shapes must be 1 * 1."));
    PADDLE_ENFORCE_EQ(dims[1],
                      1,
                      platform::errors::InvalidArgument(
                          "AttentionScalarBias shapes must be 1 * 1."));
  }

  framework::DDim out_dims({x_dims[0], D});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
  ctx->SetOutputDim("LSTMX", {1, M});
  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
  // AttentionFCOut should be reshaped to (max_seq_len, 1) at runtime
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");
}

framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
}

void AttentionLSTMOpMaker::Make() {
  AddInput("X",
           "(LoDTensor) the input is a LodTensor, which support "
           "variable-time length input sequence. The underlying tensor in "
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
  AddInput("C0",
           "(Tensor) LSTM C0"
           "This is a tensor with shape (N x D), where N is the batch size, D "
           "is the gate size."
           "C0 is necessary because of attention.");
  AddInput("H0",
           "(Tensor, optional) LSTM H0"
           "This is a tensor with shape (N x D), where N is the "
           "batch size and D is the gate size.")
      .AsDispensable();
  AddInput("AttentionWeight",
           "(Tensor) the weights of attention fc. Always relu the fc result."
           "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
           "gate size of LSTM.");
  AddInput("AttentionBias",
           "(Tensor, optional) the bias of attention fc."
           "The shape is (1 x 1)")
      .AsDispensable();
  AddInput("AttentionScalar",
           "(Tensor, optional) the scalar on the result of attentioned fc. "
           "Always relu the Scalar."
           "The shape is (1 x 1)")
      .AsDispensable();
  AddInput("AttentionScalarBias",
           "(Tensor, optional) the scalar bias of attention fc."
           "The shape is (1 x 1)")
      .AsDispensable();
  AddInput("LSTMWeight",
           "(Tensor) the combined weight of LSTM"
           " - The shape is ((D+M) x 4D), where D is the hidden gate size, M "
           "is the dim size of x"
           " - Weight = {W_forget, W_input, W_output, W_cell}");
  AddInput("LSTMBias",
           "(Tensor) the combined bias of LSTM, shape (1x4D)."
           "Note: we should add the bias of hidden and context accorindg to "
           "the same gate: "
           "{B_forget, B_input, B_output, B_cell}");
  AddOutput("Hidden",
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("AttentionedX",
            "(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size.")
      .AsIntermediate();
  AddOutput("AttentionFCOut",
            "(Tensor) (max_seq_len, 1), compute at each step.")
      .AsIntermediate();
  AddOutput("LSTMX",
            "(Tensor) the input X of LSTM for each step."
            "Shape is (1 x M), where M is the x frame size")
      .AsIntermediate();
  AddOutput(
      "LSTMOUT",
      "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
      "Shape is (1 x 4D), where M is the x frame size")
      .AsIntermediate();
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid)"
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh)"
                       "The activation for cell output, `tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh)"
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddComment(R"DOC(
Attention Long-Short Term Memory (LSTM) Operator.

Attention part:
concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D))

tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu

fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu

dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M) 

LSTM part:
use lstm_x_t as input and compute as standard LSTM.
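i.e. at each step (gate order follows LSTMWeight = {W_forget, W_input, W_output, W_cell}):

gates(1, 4D) = concat( hidden_t-1(1, D), lstm_x_t(1, M) ) * LSTMWeight + LSTMBias
f_t, i_t, o_t = act_gate( gates[0:3D] ),  c~_t = act_cand( gates[3D:4D] )
cell_t = f_t .* cell_t-1 + i_t .* c~_t
hidden_t = o_t .* act_cell( cell_t )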

)DOC");
}

// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
  if (bias) {
    phi::funcs::vec_add_bias<T, platform::avx>(n, *bias, x, y);
    phi::funcs::vec_relu<T, platform::avx>(n, y, y);
  } else {
    phi::funcs::vec_relu<T, platform::avx>(n, x, y);
  }
}

template <typename T>
inline void vec_softmax(const int n, const T* x, T* y) {
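  // numerically stable softmax:
  // y[i] = exp(x[i] - max(x)) / sum_j exp(x[j] - max(x))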
  T scalar = x[0];
  // max
  for (int i = 1; i < n; ++i) {
    scalar = scalar < x[i] ? x[i] : scalar;
  }
  phi::funcs::vec_add_bias<T, platform::avx>(n, -scalar, x, y);  // sub
  phi::funcs::vec_exp<T>(n, y, y);                               // exp
  // sum
  scalar = T(0);
  for (int i = 0; i < n; ++i) {
    scalar += y[i];
  }
  phi::funcs::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
}

template <typename T>
class AttentionLSTMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using DeviceContext = paddle::platform::CPUDeviceContext;

    auto* x = ctx.Input<LoDTensor>("X");
    auto* h0 = ctx.Input<Tensor>("H0");
    auto* c0 = ctx.Input<Tensor>("C0");
    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
    auto* atten_b = ctx.Input<Tensor>("AttentionBias");
    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");

    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    auto* atted_x = ctx.Output<Tensor>("AttentionedX");
    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
    auto* lstm_x = ctx.Output<Tensor>("LSTMX");
    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");

    // some shapes must be set here because InferShape cannot get the LoD info
    auto x_lod = x->lod();
    const int N = x_lod[0].size() - 1;  // batch size
    auto x_dims = x->dims();            // T x M
    auto w_dims = lstm_w->dims();       // (D+M) x 4D
    const int total_T = x_dims[0];
    const int M = x_dims[1];      // x frame size
    const int D = w_dims[1] / 4;  // gate frame size
    const int D2 = D * 2;
    const int D3 = D * 3;
    const int D4 = w_dims[1];
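    // offsets into the 4*D-wide gate buffer (LSTMOUT):
    // [0, D) forget, [D, D2) input, [D2, D3) output, [D3, D4) candidate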
    int max_seq_len = x_lod[0][1];
    for (int i = 1; i < N; ++i) {
      int len = x_lod[0][i + 1] - x_lod[0][i];
      max_seq_len = max_seq_len < len ? len : max_seq_len;
    }
    PADDLE_ENFORCE_EQ(
        x_lod.size(),
        1UL,
        platform::errors::InvalidArgument("Input(X)'s lod size must be 1."));
    PADDLE_ENFORCE_EQ(
        c0->dims()[0],
        N,
        platform::errors::InvalidArgument("C0 dims should be %d x %d.", N, D));
    fc_out->Resize({max_seq_len, 1});

    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
    if (platform::MayIUse(platform::avx)) {
      phi::funcs::VecActivations<T, platform::avx> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    } else {
      phi::funcs::VecActivations<T, platform::isa_any> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    }

    const T* x_data = x->data<T>();
    const T* h0_data = h0 ? h0->data<T>() : NULL;
    const T* c0_data = c0->data<T>();
    const T* lstm_w_data = lstm_w->data<T>();
    const T* lstm_b_data = lstm_b->data<T>();
    const T* atten_w_data = atten_w->data<T>();
    const T* atten_b_data = atten_b ? atten_b->data<T>() : NULL;
    const T* atten_scalar_data = atten_scalar ? atten_scalar->data<T>() : NULL;
    const T* atten_scalar_bias_data =
        atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;

    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
    T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
    T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
    T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
    T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());

    auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(ctx);

    // atted_x (T x 1) = x (T x M) * the first M rows of atten_wgt ((M+D) x 1),
    // plus the attention bias if present
    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
    phi::funcs::FCFunctor<DeviceContext, T> fc;
    fc(dev_ctx,
       total_T,
       1,
       M,
       x_data,
       atten_w_data,
       atted_x_data,
       atten_b_data);

    const T* cur_atten_x_data = atted_x_data;
    const T* cur_x_data = x_data;
    const T* prev_cell_data = NULL;
    const T* prev_hidden_data = NULL;
    T* cur_cell_out_data = cell_out_data;
    T* cur_hidden_out_data = hidden_out_data;
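    // sequences are processed one by one; the cur_* pointers walk through X,
    // AttentionedX, Cell and Hidden sequence by sequence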
    for (int i = 0; i < N; ++i) {
      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
      prev_cell_data = c0_data + i * D;
      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
      for (int step = 0; step < seq_len; ++step) {
        /// 1. compute attention vector
        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
        T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
        // 1b. add cell bias and relu
        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
        // 1c. fc scalar
        if (atten_scalar_data) {
          blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
          bias_relu<T>(
              seq_len, fc_out_data, atten_scalar_bias_data, fc_out_data);
        }
        // 1d. softmax
        vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
        // 1e. weighted-sum pool over time:
        // lstm_x (1 x M) = fc_out (1 x seq_len) * x (seq_len x M)
        fc(dev_ctx, 1, M, seq_len, fc_out_data, cur_x_data, lstm_x_data);

        /// 2. compute LSTM step
        // lstm weight : concat[forget , input , output , tilde]
        // shape : (D + M) x (4 * D)
        // fc inputX(1xM) * weightX(M*(4D))  => 1 x 4D
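        // LSTMWeight rows: the first D rows multiply prev_hidden (the GEMM
        // below), the next M rows multiply lstm_x, hence the D * D4 offset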
        blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data);
        if (prev_hidden_data) {
          blas.GEMM(CblasNoTrans,
                    CblasNoTrans,
                    1,
                    D4,
                    D,
                    static_cast<T>(1),
                    prev_hidden_data,
                    D,
                    lstm_w_data,
                    D4,
                    static_cast<T>(1),
                    lstm_out_data,
                    D4);
        }
        // since the batch here is a single row, the bias can be added with VADD
        blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);

        // gate act: sigmoid
        act_gate(D3, lstm_out_data, lstm_out_data);
        // candidate act: tanh
        act_cand(D, lstm_out_data + D3, lstm_out_data + D3);

        // a = forget * prev_cell
        blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);

        // b = input * tilde
        blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);

        // cell_out = a + b
        blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);

        // state act tanh(cell_out) * output_gate
        act_cell(D, cur_cell_out_data, lstm_out_data);
        blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);

        prev_hidden_data = cur_hidden_out_data;
        prev_cell_data = cur_cell_out_data;
        cur_cell_out_data = cur_cell_out_data + D;
        cur_hidden_out_data = cur_hidden_out_data + D;
      }
      cur_x_data = cur_x_data + seq_len * M;
      cur_atten_x_data = cur_atten_x_data + seq_len;
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(attention_lstm,
                  ops::AttentionLSTMOp,
                  ops::AttentionLSTMOpMaker);

REGISTER_OP_CPU_KERNEL(attention_lstm,
                       ops::AttentionLSTMKernel<float>,
                       ops::AttentionLSTMKernel<double>);