fusion_lstm_op.cc 21.7 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/fused/fusion_lstm_op.h"
T
tensor-tang 已提交
16
#include <string>
17
#include "paddle/fluid/operators/jit/kernels.h"
T
tensor-tang 已提交
18
#include "paddle/fluid/operators/math/blas.h"
19
#include "paddle/fluid/operators/math/fc.h"
T
tensor-tang 已提交
20
#include "paddle/fluid/operators/math/sequence2batch.h"
T
tensor-tang 已提交
21

T
tensor-tang 已提交
22 23 24 25
namespace paddle {
namespace operators {

// Checks the inputs of fusion_lstm and infers the shapes of its outputs.
//
// Expected inputs (T = total time steps, M = input width, D = hidden size):
//   X: (T x M), WeightX: (M x 4D), WeightH: (D x 4D),
//   Bias: (1 x 7D) when `use_peepholes` is true, otherwise (1 x 4D),
//   optional H0 / C0, which must have identical shapes (C0 is required
//   whenever H0 is given).
// Outputs whose shapes are set here:
//   Hidden, Cell: (T x D), sharing X's LoD;
//   CheckedCell: (2 x D), only when `use_peepholes` is true;
//   XX: (T x 4D) when `use_seq` is true, otherwise (T x min(M, 4D)); shares
//       X's LoD;
//   BatchedInput: (T x 4D), BatchedHidden / BatchedCell: (T x D), only when
//       `use_seq` is false.
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "fusion_lstm");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "Input(X)'s rank must be 2, but received x's rank "
                        "is:%d, x dim is:[%s]",
                        x_dims.size(), x_dims));

  if (ctx->HasInput("H0")) {
    OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "fusion_lstm");
    auto h_dims = ctx->GetInputDim("H0");
    auto c_dims = ctx->GetInputDim("C0");
    PADDLE_ENFORCE_EQ(h_dims, c_dims,
                      platform::errors::InvalidArgument(
                          "The dimension of Input(H0) and Input(C0) should be "
                          "same, but received h0 dims is:[%s], c0 dims is:[%s]",
                          h_dims, c_dims));
  }

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightX) should be 2, but received "
                        "WeightX's rank is:%d, WeightX dim is:[%s]",
                        wx_dims.size(), wx_dims));
  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(WeightX) "
                        "should equal to second dimension of Input(X), but "
                        "received WeightX first dim is:%d, X second dim is:%d",
                        wx_dims[0], x_dims[1]));
  // The hidden size D is derived from WeightX's width (4D). Reject widths
  // that are not a multiple of 4. Otherwise the truncated D still passes the
  // WeightH checks below, while the kernel reads WeightX with a row stride
  // of 4D that differs from its real width.
  PADDLE_ENFORCE_EQ(wx_dims[1] % 4, 0,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightX) should be "
                        "4 * D (a multiple of 4), but received WeightX second "
                        "dim is:%d",
                        wx_dims[1]));

  int frame_size = wx_dims[1] / 4;
  auto wh_dims = ctx->GetInputDim("WeightH");

  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightH) should be 2, but received "
                        "WeightH rank is:%d, WeightH dim is:[%s]",
                        wh_dims.size(), wh_dims));
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(WeightH) "
                        "should equal to frame size, but received WeightH "
                        "first dim is:%d, frame size is:%d.",
                        wh_dims[0], frame_size));

  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightH) "
                        "should equal to 4 * frame_size, but received WeightH "
                        "second dimension is:%d, frame size is:%d.",
                        wh_dims[1], frame_size));

  auto b_dims = ctx->GetInputDim("Bias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(Bias) should be 2, but received "
                        "Bias rank is:%d, Bias dim is:[%s]",
                        b_dims.size(), b_dims));
  PADDLE_ENFORCE_EQ(b_dims[0], 1,
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(Bias) should be 1, but "
                        "received Bias's dimension is:[%s]",
                        b_dims));

  if (ctx->Attrs().Get<bool>("use_peepholes")) {
    // 4D gate biases followed by the 3D peephole weights.
    PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
                      platform::errors::InvalidArgument(
                          "The second dimension of Input(Bias) should be "
                          "7 * %d if enable peepholes connection, but received "
                          "Bias dim is:[%s]",
                          frame_size, b_dims));
    ctx->SetOutputDim("CheckedCell", {2, frame_size});
  } else {
    PADDLE_ENFORCE_EQ(
        b_dims[1], 4 * frame_size,
        platform::errors::InvalidArgument(
            "The second dimension of Input(Bias) should be "
            "4 * %d if disable peepholes, but received Bias dim is:[%s]",
            frame_size, b_dims));
  }

  framework::DDim out_dims({x_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");
  int xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    xx_width = wx_dims[1];
  } else {
    // In batch mode XX holds whichever of X (width M) or X * WeightX
    // (width 4D) is narrower, so the smaller width is used.
    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];

    OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
                   "fusion_lstm");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedHidden"), "Output", "BatchedHidden",
                   "fusion_lstm");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedCell"), "Output", "BatchedCell",
                   "fusion_lstm");
    OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
                   "fusion_lstm");
    OP_INOUT_CHECK(ctx->HasOutput("ReorderedC0"), "Output", "ReorderedC0",
                   "fusion_lstm");

    ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
    ctx->SetOutputDim("BatchedHidden", out_dims);
    ctx->SetOutputDim("BatchedCell", out_dims);
  }
  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // The kernel is selected by the data type of the input sequence X, on the
  // device context of the current execution.
  const auto input_data_type =
      OperatorWithKernel::IndicateVarDataType(ctx, "X");
  return framework::OpKernelType(input_data_type, ctx.device_context());
}

void FusionLSTMOpMaker::Make() {
T
tensor-tang 已提交
153
  AddInput("X",
T
tensor-tang 已提交
154
           "(LoDTensor) the input is a LodTensor, which support "
T
tensor-tang 已提交
155
           "variable-time length input sequence. The underlying tensor in "
T
tensor-tang 已提交
156 157 158 159 160 161 162 163 164
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
  AddInput("WeightX",
           "(Tensor) the learnable weights of X."
           " - The shape is (M x 4D), where M is the dim size of x, D is the "
           "hidden size. "
           " - Weight = {W_cx, W_ix, W_fx, W_ox}");
  AddInput("WeightH",
           "(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
T
tensor-tang 已提交
165 166 167
           " - The shape is (D x 4D), where D is the hidden size. "
           " - Weight = {W_ch, W_ih, W_fh, W_oh}");
  AddInput("Bias",
T
tensor-tang 已提交
168 169
           "(Tensor) the learnable weights. Almost same as LSTMOp"
           "Note: we should add the fc bias into this (1x4D) in bias."
T
tensor-tang 已提交
170 171 172 173 174 175 176 177
           "input-hidden bias weight and peephole connections weight if "
           "setting `use_peepholes` True. "
           "1. `use_peepholes = False` "
           " - The shape is (1 x 4D). "
           " - Bias = {b_c, b_i, b_f, b_o}."
           "2. `use_peepholes = True` "
           " - The shape is (1 x 7D). "
           " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
T
tensor-tang 已提交
178 179 180 181 182 183 184 185 186 187 188 189
  AddInput("H0",
           "(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size and D is the hidden size.")
      .AsDispensable();
  AddInput("C0",
           "(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size. `H0` and `C0` can be NULL but only at the same time.")
      .AsDispensable();
T
tensor-tang 已提交
190
  AddOutput("Hidden",
T
tensor-tang 已提交
191
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
T
tensor-tang 已提交
192 193
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
T
tensor-tang 已提交
194
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
T
tensor-tang 已提交
195
            "The shape is (T x D), and lod is the same with the `Input`.");
T
tensor-tang 已提交
196
  AddOutput("XX",
T
tensor-tang 已提交
197 198 199
            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
T
tensor-tang 已提交
200 201
            " D is the hidden size, M is the dim size of x input.")
      .AsIntermediate();
T
tensor-tang 已提交
202 203 204 205 206
  AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate();
  AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
  AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
T
tensor-tang 已提交
207 208
  AddOutput("CheckedCell", "(Tensor) (2 x D) only for peephole.")
      .AsIntermediate();
T
tensor-tang 已提交
209
  AddAttr<bool>("use_peepholes",
翟飞跃 已提交
210
                "(bool, default: True) "
T
tensor-tang 已提交
211 212 213
                "whether to enable diagonal/peephole connections.")
      .SetDefault(true);
  AddAttr<bool>("is_reverse",
翟飞跃 已提交
214
                "(bool, default: False) "
T
tensor-tang 已提交
215 216
                "whether to compute reversed LSTM.")
      .SetDefault(false);
T
tensor-tang 已提交
217
  AddAttr<bool>("use_seq",
翟飞跃 已提交
218
                "(bool, default: True) "
T
tensor-tang 已提交
219 220
                "whether to use seq mode to compute.")
      .SetDefault(true);
T
tensor-tang 已提交
221 222 223 224 225 226 227 228
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid)"
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh)"
翟飞跃 已提交
229
                       "The activation for cell output, `tanh` by default.")
T
tensor-tang 已提交
230 231 232 233 234 235 236 237 238
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh)"
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddComment(R"DOC(
T
tensor-tang 已提交
239 240
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.
T
tensor-tang 已提交
241 242 243
)DOC");
}

T
tensor-tang 已提交
244
template <typename T>
T
tensor-tang 已提交
245
class FuisonLSTMKernel : public framework::OpKernel<T> {
T
tensor-tang 已提交
246
 public:
247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265
#define INIT_BASE_DEFINES                                   \
  using DeviceContext = paddle::platform::CPUDeviceContext; \
  auto* x = ctx.Input<LoDTensor>("X");                      \
  auto* h0 = ctx.Input<Tensor>("H0");                       \
  auto* c0 = ctx.Input<Tensor>("C0");                       \
  auto* wx = ctx.Input<Tensor>("WeightX");                  \
  auto* wh = ctx.Input<Tensor>("WeightH");                  \
  auto* bias = ctx.Input<Tensor>("Bias");                   \
  auto* xx = ctx.Output<LoDTensor>("XX");                   \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");       \
  auto* cell_out = ctx.Output<LoDTensor>("Cell");           \
  bool is_reverse = ctx.Attr<bool>("is_reverse");           \
  bool use_peepholes = ctx.Attr<bool>("use_peepholes");     \
  auto x_dims = x->dims();   /* T x M*/                     \
  auto wh_dims = wh->dims(); /* D x 4D*/                    \
  const int M = x_dims[1];                                  \
  const int D = wh_dims[0];                                 \
  const int D4 = wh_dims[1]

266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
#define INIT_OTHER_DEFINES                                                     \
  const T* x_data = x->data<T>();                                              \
  const T* wx_data = wx->data<T>();                                            \
  const T* wh_data = wh->data<T>();                                            \
  /* diagonal weight*/                                                         \
  const T* wp_data = bias->data<T>() + D4;                                     \
  /* for peephole only*/                                                       \
  T* checked_cell_data = nullptr;                                              \
  auto place = ctx.GetPlace();                                                 \
  if (use_peepholes) {                                                         \
    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                           \
    auto* checked_cell = ctx.Output<Tensor>("CheckedCell");                    \
    checked_cell_data = checked_cell->mutable_data<T>(place);                  \
  }                                                                            \
  const jit::lstm_attr_t attr(                                                 \
      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),         \
      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")),            \
      use_peepholes);                                                          \
  jit::lstm_t one_step;                                                        \
  one_step.wp = wp_data;                                                       \
  one_step.checked = checked_cell_data;                                        \
  auto ComputeC1H1 =                                                           \
      jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
          attr);                                                               \
  auto ComputeCtHt =                                                           \
      jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
          attr)
294 295

// Wh GEMM
T
tensor-tang 已提交
296 297 298 299
#define GEMM_WH_ADDON(bs, prev, out)                                           \
  blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
            wh_data, D4, static_cast<T>(1), out, D4)

T
tensor-tang 已提交
300
  void SeqCompute(const framework::ExecutionContext& ctx) const {
301 302
    INIT_BASE_DEFINES;
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
303
    auto x_lod = x->lod();
T
tensor-tang 已提交
304
    const int total_T = x_dims[0];
T
tensor-tang 已提交
305
    const int N = x_lod[0].size() - 1;
T
tensor-tang 已提交
306 307
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    const T* c0_data = c0 ? c0->data<T>() : nullptr;
T
tensor-tang 已提交
308
    T* xx_data = xx->mutable_data<T>(place);
T
tensor-tang 已提交
309 310
    T* h_out_data = hidden_out->mutable_data<T>(place);
    T* c_out_data = cell_out->mutable_data<T>(place);
T
tensor-tang 已提交
311
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
312 313 314 315

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::FCFunctor<DeviceContext, T> fc;
    fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
B
Brian Liu 已提交
316

T
tensor-tang 已提交
317 318 319 320 321
    int xx_offset = D4;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 4;
T
tensor-tang 已提交
322 323
      h_out_data = h_out_data + offset;
      c_out_data = c_out_data + offset;
T
tensor-tang 已提交
324 325 326 327
      xx_offset = -D4;
      gate_offset = -D;
    }

328 329 330 331 332 333 334 335 336 337
    for (int i = 0; i < N; ++i) {
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
      const T* prev_c_data = nullptr;
      const T* prev_h_data = nullptr;
      int tstart = 0;
      if (h0_data) {
        prev_h_data = h0_data + bid * D;
        prev_c_data = c0_data + bid * D;
      } else {
338 339 340
        one_step.gates = xx_data;
        one_step.ct = c_out_data;
        one_step.ht = h_out_data;
341
        ComputeC1H1(&one_step, &attr);
342 343 344 345 346 347 348
        tstart = 1;
        // move one step
        prev_h_data = h_out_data;
        prev_c_data = c_out_data;
        xx_data = xx_data + xx_offset;
        h_out_data = h_out_data + gate_offset;
        c_out_data = c_out_data + gate_offset;
T
tensor-tang 已提交
349
      }
350 351
      for (int step = tstart; step < seq_len; ++step) {
        GEMM_WH_ADDON(1, prev_h_data, xx_data);
352 353 354 355 356

        one_step.gates = xx_data;
        one_step.ct_1 = prev_c_data;
        one_step.ct = c_out_data;
        one_step.ht = h_out_data;
357
        ComputeCtHt(&one_step, &attr);
358 359 360 361 362 363
        // move one step
        prev_h_data = h_out_data;
        prev_c_data = c_out_data;
        xx_data = xx_data + xx_offset;
        h_out_data = h_out_data + gate_offset;
        c_out_data = c_out_data + gate_offset;
T
tensor-tang 已提交
364
      }
T
tensor-tang 已提交
365
    }
T
tensor-tang 已提交
366 367 368
  }

  void BatchCompute(const framework::ExecutionContext& ctx) const {
369
    INIT_BASE_DEFINES;
T
tensor-tang 已提交
370
    if (x->lod()[0].size() == 2) {
371
      xx->Resize({x_dims[0], D4});
T
tensor-tang 已提交
372
      SeqCompute(ctx);
T
tensor-tang 已提交
373
      return;
T
tensor-tang 已提交
374
    }
375
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
376

T
tensor-tang 已提交
377 378 379 380 381 382 383 384 385 386 387
    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
    auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
    auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
    T* xx_data = xx->mutable_data<T>(place);
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
    T* batched_h_out_data = batched_h_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
    cell_out->mutable_data<T>(place);
T
tensor-tang 已提交
388

T
tensor-tang 已提交
389
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
T
tensor-tang 已提交
390 391
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
392
    math::FCFunctor<DeviceContext, T> fc;
T
tensor-tang 已提交
393
    if (M > D4) {
394
      fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>());
T
tensor-tang 已提交
395
      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
T
tensor-tang 已提交
396 397
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
T
tensor-tang 已提交
398
      batched_input->set_lod(xx->lod());
399 400
      fc(dev_ctx, x_dims[0], D4, M, xx_data, wx_data, batched_input_data,
         bias->data<T>());
T
tensor-tang 已提交
401 402
    }

T
tensor-tang 已提交
403 404 405 406 407 408 409
    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});
    reordered_c0->Resize({max_bs, D});

    int tstart = 0;
T
tensor-tang 已提交
410 411
    T* prev_h_data = nullptr;
    T* prev_c_data = nullptr;
T
tensor-tang 已提交
412 413 414 415 416 417
    if (h0) {
      // reorder h0, c0
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
      T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
      const T* h0_data = h0->data<T>();
      const T* c0_data = c0->data<T>();
T
tensor-tang 已提交
418 419
      prev_h_data = reordered_h0_data;
      prev_c_data = reordered_c0_data;
420
      size_t sz = D;
T
tensor-tang 已提交
421
      for (int i = 0; i < max_bs; ++i) {
422 423
        blas.VCOPY(sz, h0_data + seq_order[i] * D, reordered_h0_data);
        blas.VCOPY(sz, c0_data + seq_order[i] * D, reordered_c0_data);
T
tensor-tang 已提交
424 425 426 427
        reordered_h0_data += D;
        reordered_c0_data += D;
      }
    } else {
T
tensor-tang 已提交
428 429 430 431 432
      // compute without h0, c0
      T* cur_in_data = batched_input_data;
      T* cur_h_out_data = batched_h_out_data;
      T* cur_c_out_data = batched_c_out_data;
      for (int i = 0; i < max_bs; ++i) {
433 434 435
        one_step.gates = cur_in_data;
        one_step.ct = cur_c_out_data;
        one_step.ht = cur_h_out_data;
436
        ComputeC1H1(&one_step, &attr);
437

T
tensor-tang 已提交
438 439 440 441 442
        cur_in_data += D4;
        cur_c_out_data += D;
        cur_h_out_data += D;
      }
      tstart = 1;
T
tensor-tang 已提交
443 444
      prev_h_data = batched_h_out_data;
      prev_c_data = batched_c_out_data;
T
tensor-tang 已提交
445
    }
446 447

    // compute kernel part
T
tensor-tang 已提交
448 449
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
T
tensor-tang 已提交
450 451 452 453
    const int offset = tstart * max_bs * D;
    batched_input_data = batched_input_data + offset * 4;
    batched_h_out_data = batched_h_out_data + offset;
    batched_c_out_data = batched_c_out_data + offset;
454 455 456 457 458 459 460 461
    for (int step = tstart; step < max_seq_len; ++step) {
      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
      GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
      T* cur_in_data = batched_input_data;
      T* cur_prev_c_data = prev_c_data;
      T* cur_c_out_data = batched_c_out_data;
      T* cur_h_out_data = batched_h_out_data;
      for (int i = 0; i < cur_bs; ++i) {
462 463 464 465
        one_step.gates = cur_in_data;
        one_step.ct_1 = cur_prev_c_data;
        one_step.ct = cur_c_out_data;
        one_step.ht = cur_h_out_data;
T
tensor-tang 已提交
466
        ComputeCtHt(&one_step, &attr);
467

468 469 470 471 472
        // move one batch
        cur_in_data += D4;
        cur_prev_c_data += D;
        cur_c_out_data += D;
        cur_h_out_data += D;
T
tensor-tang 已提交
473
      }
474 475 476 477 478 479
      // move one step
      prev_c_data = batched_c_out_data;
      prev_h_data = batched_h_out_data;
      batched_c_out_data = cur_c_out_data;
      batched_h_out_data = cur_h_out_data;
      batched_input_data = cur_in_data;
T
tensor-tang 已提交
480 481 482
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
T
tensor-tang 已提交
483 484 485 486
    batched_h_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_h_out, hidden_out);
    batched_c_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_c_out, cell_out);
T
tensor-tang 已提交
487
  }
T
tensor-tang 已提交
488

T
tensor-tang 已提交
489
  void Compute(const framework::ExecutionContext& ctx) const override {
T
tensor-tang 已提交
490
    if (ctx.Attr<bool>("use_seq")) {
T
tensor-tang 已提交
491 492 493 494 495
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }
T
tensor-tang 已提交
496 497

#undef GEMM_WH_ADDON
498 499
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
T
tensor-tang 已提交
500 501 502 503 504 505
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
506
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker);
T
tensor-tang 已提交
507

T
tensor-tang 已提交
508 509
REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
                       ops::FuisonLSTMKernel<double>);