attention_lstm_op.cc 17.5 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/attention_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
T
tensor-tang 已提交
18
#include "paddle/fluid/operators/math/cpu_vec.h"
T
tensor-tang 已提交
19
#include "paddle/fluid/operators/math/fc_compute.h"
T
tensor-tang 已提交
20
#include "paddle/fluid/platform/cpu_info.h"
21

T
tensor-tang 已提交
22 23 24
namespace paddle {
namespace operators {

25
// Validates the inputs/outputs of attention_lstm and infers the static output
// shapes. Shapes that depend on LoD (e.g. AttentionFCOut) are fixed in the
// kernel at runtime instead.
//   X:               (T x M)
//   LSTMWeight:      ((D + M) x 4D)
//   LSTMBias:        (1 x 4D)
//   C0 / H0:         (N x D)
//   AttentionWeight: ((M + D) x 1)
//   AttentionBias / AttentionScalar / AttentionScalarBias: (1 x 1), optional
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  // Required inputs.
  PADDLE_ENFORCE(ctx->HasInput("X"),
                 "Assert only one Input(X) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasInput("C0"),
                 "Assert only one Input(C0) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
                 "Assert only one Input(LSTMWeight) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
                 "Assert only one Input(LSTMBias) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
                 "Assert only one Input(AttentionWeight) of AttentionLSTM.");

  // Required outputs.
  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                 "Assert only one Output(Hidden) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                 "Assert only one Output(Cell) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
                 "Assert only one Output(AttentionedX) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
                 "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
                 "Assert only one Output(LSTMX) of AttentionLSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
                 "Assert only one Output(LSTMOUT) of AttentionLSTM.");

  // X: check the rank before reading x_dims[1].
  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
  const int M = x_dims[1];

  // LSTMWeight: check the rank before reading w_dims[1], and make sure the
  // second dim splits evenly into the four gates.
  auto w_dims = ctx->GetInputDim("LSTMWeight");
  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(w_dims[1] % 4, 0,
                    "LSTMWeight's second dim should be a multiple of 4.");
  const int D = w_dims[1] / 4;
  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);

  auto b_dims = ctx->GetInputDim("LSTMBias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);

  auto c_dims = ctx->GetInputDim("C0");
  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
  if (ctx->HasInput("H0")) {
    auto h_dims = ctx->GetInputDim("H0");
    PADDLE_ENFORCE(h_dims == c_dims,
                   "The dimension of Input(H0) and Input(C0) "
                   "should be the same.");
  }

  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
                    "Input(AttentionWeight)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);

  if (ctx->HasInput("AttentionBias")) {
    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
                      "Input(AttentionBias)'s rank must be 2.");
    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
                      "AttentionBias shapes must be 1 * 1.");
    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
                      "AttentionBias shapes must be 1 * 1.");
  }

  if (ctx->HasInput("AttentionScalar")) {
    auto dims = ctx->GetInputDim("AttentionScalar");
    PADDLE_ENFORCE_EQ(dims.size(), 2,
                      "Input(AttentionScalar)'s rank must be 2.");
    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
  }

  if (ctx->HasInput("AttentionScalarBias")) {
    auto dims = ctx->GetInputDim("AttentionScalarBias");
    // A scalar bias is only meaningful together with the scalar itself.
    PADDLE_ENFORCE(
        ctx->HasInput("AttentionScalar"),
        "AttentionScalar should not be null when have AttentionScalarBias.");
    PADDLE_ENFORCE_EQ(dims.size(), 2,
                      "Input(AttentionScalarBias)'s rank must be 2.");
    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
  }

  framework::DDim out_dims({x_dims[0], D});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
  ctx->SetOutputDim("LSTMX", {1, M});
  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
  // AttentionFCOut should be reshape as (maxseqlen,1) in runtime
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");
}

122
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // The kernel is selected by the element type of Input(X) and placed on the
  // current device context.
  const auto* input_x = ctx.Input<framework::LoDTensor>("X");
  const auto data_type = input_x->type();
  return framework::OpKernelType(data_type, ctx.device_context());
}

128
void AttentionLSTMOpMaker::Make() {
T
tensor-tang 已提交
129 130 131 132 133
  AddInput("X",
           "(LoDTensor) the input is a LodTensor, which support "
           "variable-time length input sequence. The underlying tensor in "
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
134 135 136 137 138
  AddInput("C0",
           "(Tensor) LSTM C0"
           "This is a tensor with shape (N x D), where N is the batch size, D "
           "is the gate size."
           "C0 is necessary because of attention.");
T
tensor-tang 已提交
139
  AddInput("H0",
140 141 142
           "(Tensor, optional) LSTM H0"
           "This is a tensor with shape (N x D), where N is the "
           "batch size and D is the gate size.")
T
tensor-tang 已提交
143
      .AsDispensable();
144 145 146 147
  AddInput("AttentionWeight",
           "(Tensor) the weights of attention fc. Always relu the fc result."
           "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
           "gate size of LSTM.");
T
tensor-tang 已提交
148 149
  AddInput("AttentionBias",
           "(Tensor, optional) the bias of attention fc."
150 151 152 153 154 155 156 157 158 159
           "The shape is (1 x 1)")
      .AsDispensable();
  AddInput("AttentionScalar",
           "(Tensor, optional) the scalar on the result of attentioned fc. "
           "Always relu the Scalar."
           "The shape is (1 x 1)")
      .AsDispensable();
  AddInput("AttentionScalarBias",
           "(Tensor, optional) the scalar bias of attention fc."
           "The shape is (1 x 1)")
T
tensor-tang 已提交
160
      .AsDispensable();
161 162 163 164 165 166 167 168 169 170
  AddInput("LSTMWeight",
           "(Tensor) the combined weight of LSTM"
           " - The shape is ((D+M) x 4D), where D is the hidden gate size, M "
           "is the dim size of x"
           " - Weight = {W_forget, W_input, W_output, W_cell}");
  AddInput("LSTMBias",
           "(Tensor) the combined bias of LSTM, shape (1x4D)."
           "Note: we should add the bias of hidden and context accorindg to "
           "the same gate: "
           "{B_forget, B_input, B_output, B_cell}");
T
tensor-tang 已提交
171 172 173 174 175 176
  AddOutput("Hidden",
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
T
tensor-tang 已提交
177 178 179 180
  AddOutput("AttentionedX",
            "(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size.")
T
tensor-tang 已提交
181
      .AsIntermediate();
182 183
  AddOutput("AttentionFCOut",
            "(Tensor) (max_seq_len, 1), compute at each step.")
T
tensor-tang 已提交
184
      .AsIntermediate();
185 186 187 188 189 190 191 192 193
  AddOutput("LSTMX",
            "(Tensor) the input X of LSTM for each step."
            "Shape is (1 x M), where M is the x frame size")
      .AsIntermediate();
  AddOutput(
      "LSTMOUT",
      "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
      "Shape is (1 x 4D), where M is the x frame size")
      .AsIntermediate();
T
tensor-tang 已提交
194 195 196 197 198
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid)"
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
199
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
T
tensor-tang 已提交
200 201 202 203
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh)"
                       "The activation for cell output, `tanh` by defalut.")
      .SetDefault("tanh")
204
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
T
tensor-tang 已提交
205 206 207 208 209
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh)"
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
210
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
T
tensor-tang 已提交
211
  AddComment(R"DOC(
212 213 214 215 216 217 218 219 220 221 222 223 224 225
Attention Long-Short Term Memory (LSTM) Operator.

Attention part:
concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D))

tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu

fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu

dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M) 

LSTM part:
use lstm_x_t as input and compute as standard LSTM.

T
tensor-tang 已提交
226 227 228
)DOC");
}

229 230 231 232
// Adds the scalar *bias to every element, when a bias is given, and then
// clamps at zero:
//   y[i] = max(x[i] + bias[0], 0)   if bias != nullptr
//   y[i] = max(x[i], 0)             otherwise
// The kernel calls this with x == y (in place on the attention fc output).
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
  if (bias == nullptr) {
    math::vec_relu<T, platform::jit::avx>(n, x, y);
    return;
  }
  math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
  math::vec_relu<T, platform::jit::avx>(n, y, y);
}

T
tensor-tang 已提交
240 241
template <typename T>
inline void vec_softmax(const int n, const T* x, T* y) {
242 243 244 245 246
  T scalar = x[0];
  // max
  for (int i = 1; i < n; ++i) {
    scalar = scalar < x[i] ? x[i] : scalar;
  }
T
tensor-tang 已提交
247 248
  math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y);  // sub
  math::vec_exp<T>(n, y, y);                                    // exp
249 250 251 252 253
  // sum
  scalar = T(0);
  for (int i = 0; i < n; ++i) {
    scalar += y[i];
  }
T
tensor-tang 已提交
254
  math::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
255 256
}

T
tensor-tang 已提交
257
template <typename T>
258
class AttentionLSTMKernel : public framework::OpKernel<T> {
T
tensor-tang 已提交
259 260
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
T
tensor-tang 已提交
261
    using DeviceContext = paddle::platform::CPUDeviceContext;
262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278

    auto* x = ctx.Input<LoDTensor>("X");
    auto* h0 = ctx.Input<Tensor>("H0");
    auto* c0 = ctx.Input<Tensor>("C0");
    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
    auto* atten_b = ctx.Input<Tensor>("AttentionBias");
    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");

    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    auto* atted_x = ctx.Output<Tensor>("AttentionedX");
    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
    auto* lstm_x = ctx.Output<Tensor>("LSTMX");
    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
T
tensor-tang 已提交
279 280 281 282 283

    // some shape should be reshape here since infershape can not get lod info
    auto x_lod = x->lod();
    const int N = x_lod[0].size() - 1;  // batch size
    auto x_dims = x->dims();            // T x M
T
tensor-tang 已提交
284 285 286 287
    auto w_dims = lstm_w->dims();       // (D+M) x 4D
    const int total_T = x_dims[0];
    const int M = x_dims[1];      // x frame size
    const int D = w_dims[1] / 4;  // gate frame size
T
tensor-tang 已提交
288 289 290 291 292 293 294 295 296 297 298
    const int D2 = D * 2;
    const int D3 = D * 3;
    const int D4 = w_dims[1];
    int max_seq_len = x_lod[0][1];
    for (int i = 1; i < N; ++i) {
      int len = x_lod[0][i + 1] - x_lod[0][i];
      max_seq_len = max_seq_len < len ? len : max_seq_len;
    }
    PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
    PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
    fc_out->Resize({max_seq_len, 1});
T
tensor-tang 已提交
299

300
    std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
T
tensor-tang 已提交
301 302 303 304 305 306 307 308 309 310 311 312 313 314
    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
    if (platform::jit::MayIUse(platform::jit::avx)) {
      math::VecActivations<T, platform::jit::avx> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    } else {
      math::VecActivations<T, platform::jit::isa_any> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    }
T
tensor-tang 已提交
315

T
tensor-tang 已提交
316
    const T* x_data = x->data<T>();
T
tensor-tang 已提交
317
    const T* h0_data = h0 ? h0->data<T>() : NULL;
318 319 320 321 322 323 324 325 326
    const T* c0_data = c0->data<T>();
    const T* lstm_w_data = lstm_w->data<T>();
    const T* lstm_b_data = lstm_b->data<T>();
    const T* atten_w_data = atten_w->data<T>();
    const T* atten_b_data = atten_b ? atten_b->data<T>() : NULL;
    const T* atten_scalar_data = atten_scalar ? atten_scalar->data<T>() : NULL;
    const T* atten_scalar_bias_data =
        atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;

T
tensor-tang 已提交
327 328 329 330 331 332
    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
    T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
    T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
    T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
    T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
333 334 335

    // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
T
tensor-tang 已提交
336
    math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
337 338
                                      atted_x_data, atten_b_data);

T
tensor-tang 已提交
339
    const T* cur_atten_x_data = atted_x_data;
340 341 342 343 344
    const T* cur_x_data = x_data;
    const T* prev_cell_data = NULL;
    const T* prev_hidden_data = NULL;
    T* cur_cell_out_data = cell_out_data;
    T* cur_hidden_out_data = hidden_out_data;
T
tensor-tang 已提交
345
    for (int i = 0; i < N; ++i) {
T
tensor-tang 已提交
346
      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
347
      prev_cell_data = c0_data + i * D;
T
tensor-tang 已提交
348
      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
349
      for (int step = 0; step < seq_len; ++step) {
T
tensor-tang 已提交
350 351
        /// 1. compute attention vector
        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
T
tensor-tang 已提交
352
        T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
T
tensor-tang 已提交
353 354 355
        // 1b. add cell bias and relu
        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
        // 1c. fc scalar
356
        if (atten_scalar_data) {
T
tensor-tang 已提交
357
          blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
358 359 360
          bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                       fc_out_data);
        }
T
tensor-tang 已提交
361
        // 1d. softmax
T
tensor-tang 已提交
362
        vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
363 364 365 366
        // mul x(seq_len*M) and sum pool
        math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                          cur_x_data, lstm_x_data);

T
tensor-tang 已提交
367
        /// 2. compute LSTM step
368 369 370 371 372 373 374 375 376 377 378 379 380
        // lstm weight : concat[forget , input , output , tilde]
        // shape : (D + M) x (4 * D)
        // fc inputX(1xM) * weightX(M*(4D))  => 1 x 4D
        blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data);
        if (prev_hidden_data) {
          blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
                    prev_hidden_data, D, lstm_w_data, D4, static_cast<T>(1),
                    lstm_out_data, D4);
        }
        // since input is 1xM, so can use add bias
        blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);

        // gate act: sigmoid
381
        act_gate(D3, lstm_out_data, lstm_out_data);
382
        // candicate act: tanh
383
        act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
384 385 386 387 388

        // a = forget * prev_cell
        blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);

        // b = input * tilde
T
tensor-tang 已提交
389
        blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);
390 391 392 393 394

        // cell_out = a + b
        blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);

        // state act tanh(cell_out) * output_gate
395
        act_cell(D, cur_cell_out_data, lstm_out_data);
T
tensor-tang 已提交
396
        blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
397

T
tensor-tang 已提交
398
        prev_hidden_data = cur_hidden_out_data;
399 400 401
        prev_cell_data = cur_cell_out_data;
        cur_cell_out_data = cur_cell_out_data + D;
        cur_hidden_out_data = cur_hidden_out_data + D;
T
tensor-tang 已提交
402
      }
403
      cur_x_data = cur_x_data + seq_len * M;
T
tensor-tang 已提交
404
      cur_atten_x_data = cur_atten_x_data + seq_len;
T
tensor-tang 已提交
405 406 407 408 409 410 411 412
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
413 414
REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
                  ops::AttentionLSTMOpMaker,
T
tensor-tang 已提交
415 416
                  paddle::framework::DefaultGradOpDescMaker<true>);

T
tensor-tang 已提交
417 418
REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
                       ops::AttentionLSTMKernel<double>);