fusion_lstm_op.cc 23.7 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/fused/fusion_lstm_op.h"
16

T
tensor-tang 已提交
17
#include <string>
18

19
#include "paddle/fluid/operators/jit/kernels.h"
20
#include "paddle/phi/kernels/funcs/blas/blas.h"
21
#include "paddle/phi/kernels/funcs/fc_functor.h"
F
Feiyu Chan 已提交
22
#include "paddle/phi/kernels/funcs/sequence2batch.h"
T
tensor-tang 已提交
23

T
tensor-tang 已提交
24 25 26 27
namespace paddle {
namespace operators {

// Validates the shapes of all fusion_lstm inputs and sets the output shapes.
//
// X is (T x M), WeightX is (M x 4D), WeightH is (D x 4D), and Bias is (1 x 4D),
// or (1 x 7D) when `use_peepholes` is set. D (frame_size) is derived from
// WeightX's second dimension. Hidden and Cell are (T x D) and share X's LoD.
// XX is (T x 4D) in seq mode and (T x min(M, 4D)) in batch mode. Batch mode
// additionally requires the Batched*/Reordered* intermediate outputs.
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "fusion_lstm");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "Input(X)'s rank must be 2, but received x's rank "
                        "is:%d, x dim is:[%s]",
                        x_dims.size(),
                        x_dims));

  // H0 and C0 are optional, but C0 is required whenever H0 is given.
  if (ctx->HasInput("H0")) {
    OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "fusion_lstm");
    auto h_dims = ctx->GetInputDim("H0");
    auto c_dims = ctx->GetInputDim("C0");
    PADDLE_ENFORCE_EQ(h_dims,
                      c_dims,
                      platform::errors::InvalidArgument(
                          "The dimension of Input(H0) and Input(C0) should be "
                          "same, but received h0 dims is:[%s], c0 dims is:[%s]",
                          h_dims,
                          c_dims));
  }

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightX) should be 2, but received "
                        "WeightX's rank is:%d, WeightX dim is:[%s]",
                        wx_dims.size(),
                        wx_dims));
  PADDLE_ENFORCE_EQ(wx_dims[0],
                    x_dims[1],
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(WeightX) "
                        "should equal to second dimension of Input(X), but "
                        "received WeightX first dim is:%d, X second dim is:%d",
                        wx_dims[0],
                        x_dims[1]));
  // WeightX stacks the 4 gates column-wise. Without this check, a width that
  // is not a multiple of 4 would be silently truncated by the division below
  // and still pass the WeightH checks. XX would then get a width that differs
  // from the 4 * D columns the kernel actually computes.
  PADDLE_ENFORCE_EQ(wx_dims[1] % 4,
                    0,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightX) should be "
                        "divisible by 4, but received WeightX second dim is:%d",
                        wx_dims[1]));

  int frame_size = wx_dims[1] / 4;
  auto wh_dims = ctx->GetInputDim("WeightH");

  PADDLE_ENFORCE_EQ(wh_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightH) should be 2, but received "
                        "WeightH rank is:%d, WeightH dim is:[%s]",
                        wh_dims.size(),
                        wh_dims));
  PADDLE_ENFORCE_EQ(wh_dims[0],
                    frame_size,
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(WeightH) "
                        "should equal to frame size, but received WeightH "
                        "first dim is:%d, frame size is:%d.",
                        wh_dims[0],
                        frame_size));

  PADDLE_ENFORCE_EQ(wh_dims[1],
                    4 * frame_size,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightH) "
                        "should equal to 4 * frame_size, but received WeightH "
                        "second dimension is:%d, frame size is:%d.",
                        wh_dims[1],
                        frame_size));

  auto b_dims = ctx->GetInputDim("Bias");
  PADDLE_ENFORCE_EQ(b_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(Bias) should be 2, but received "
                        "Bias rank is:%d, Bias dim is:[%s]",
                        b_dims.size(),
                        b_dims));
  PADDLE_ENFORCE_EQ(b_dims[0],
                    1,
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(Bias) should be 1, but "
                        "received Bias's dimension is:[%s]",
                        b_dims));

  if (ctx->Attrs().Get<bool>("use_peepholes")) {
    PADDLE_ENFORCE_EQ(b_dims[1],
                      7 * frame_size,
                      platform::errors::InvalidArgument(
                          "The second dimension of Input(Bias) should be "
                          "7 * %d if enable peepholes connection, but received "
                          "Bias dim is:[%s]",
                          frame_size,
                          b_dims));
    ctx->SetOutputDim("CheckedCell", {2, frame_size});
  } else {
    PADDLE_ENFORCE_EQ(
        b_dims[1],
        4 * frame_size,
        platform::errors::InvalidArgument(
            "The second dimension of Input(Bias) should be "
            "4 * %d if disable peepholes, but received Bias dim is:[%s]",
            frame_size,
            b_dims));
  }

  framework::DDim out_dims({x_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");
  int xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    // Seq mode: XX holds X * WeightX.
    xx_width = wx_dims[1];
  } else {
    // Batch mode: XX holds whichever is narrower, X * WeightX or the
    // reordered X. The kernel chooses the same way (M > 4D).
    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];

    OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"),
                   "Output",
                   "BatchedInput",
                   "fusion_lstm");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedHidden"),
                   "Output",
                   "BatchedHidden",
                   "fusion_lstm");
    OP_INOUT_CHECK(
        ctx->HasOutput("BatchedCell"), "Output", "BatchedCell", "fusion_lstm");
    OP_INOUT_CHECK(
        ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0", "fusion_lstm");
    OP_INOUT_CHECK(
        ctx->HasOutput("ReorderedC0"), "Output", "ReorderedC0", "fusion_lstm");

    ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
    ctx->SetOutputDim("BatchedHidden", out_dims);
    ctx->SetOutputDim("BatchedCell", out_dims);
  }
  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // The kernel is keyed on the data type of Input(X) and the current place.
  return framework::OpKernelType(
      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}

void FusionLSTMOpMaker::Make() {
T
tensor-tang 已提交
180
  AddInput("X",
T
tensor-tang 已提交
181
           "(LoDTensor) the input is a LodTensor, which support "
T
tensor-tang 已提交
182
           "variable-time length input sequence. The underlying tensor in "
T
tensor-tang 已提交
183 184 185 186 187 188 189 190 191
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
  AddInput("WeightX",
           "(Tensor) the learnable weights of X."
           " - The shape is (M x 4D), where M is the dim size of x, D is the "
           "hidden size. "
           " - Weight = {W_cx, W_ix, W_fx, W_ox}");
  AddInput("WeightH",
           "(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
T
tensor-tang 已提交
192 193 194
           " - The shape is (D x 4D), where D is the hidden size. "
           " - Weight = {W_ch, W_ih, W_fh, W_oh}");
  AddInput("Bias",
T
tensor-tang 已提交
195 196
           "(Tensor) the learnable weights. Almost same as LSTMOp"
           "Note: we should add the fc bias into this (1x4D) in bias."
T
tensor-tang 已提交
197 198 199 200 201 202 203 204
           "input-hidden bias weight and peephole connections weight if "
           "setting `use_peepholes` True. "
           "1. `use_peepholes = False` "
           " - The shape is (1 x 4D). "
           " - Bias = {b_c, b_i, b_f, b_o}."
           "2. `use_peepholes = True` "
           " - The shape is (1 x 7D). "
           " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
T
tensor-tang 已提交
205 206 207 208 209 210 211 212 213 214 215 216
  AddInput("H0",
           "(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size and D is the hidden size.")
      .AsDispensable();
  AddInput("C0",
           "(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size. `H0` and `C0` can be NULL but only at the same time.")
      .AsDispensable();
T
tensor-tang 已提交
217
  AddOutput("Hidden",
T
tensor-tang 已提交
218
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
T
tensor-tang 已提交
219 220
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
T
tensor-tang 已提交
221
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
T
tensor-tang 已提交
222
            "The shape is (T x D), and lod is the same with the `Input`.");
T
tensor-tang 已提交
223
  AddOutput("XX",
T
tensor-tang 已提交
224 225 226
            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
T
tensor-tang 已提交
227 228
            " D is the hidden size, M is the dim size of x input.")
      .AsIntermediate();
T
tensor-tang 已提交
229 230 231 232 233
  AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate();
  AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
  AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
T
tensor-tang 已提交
234 235
  AddOutput("CheckedCell", "(Tensor) (2 x D) only for peephole.")
      .AsIntermediate();
T
tensor-tang 已提交
236
  AddAttr<bool>("use_peepholes",
翟飞跃 已提交
237
                "(bool, default: True) "
T
tensor-tang 已提交
238 239 240
                "whether to enable diagonal/peephole connections.")
      .SetDefault(true);
  AddAttr<bool>("is_reverse",
翟飞跃 已提交
241
                "(bool, default: False) "
T
tensor-tang 已提交
242 243
                "whether to compute reversed LSTM.")
      .SetDefault(false);
T
tensor-tang 已提交
244
  AddAttr<bool>("use_seq",
翟飞跃 已提交
245
                "(bool, default: True) "
T
tensor-tang 已提交
246 247
                "whether to use seq mode to compute.")
      .SetDefault(true);
T
tensor-tang 已提交
248 249 250 251 252 253 254 255
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid)"
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh)"
翟飞跃 已提交
256
                       "The activation for cell output, `tanh` by default.")
T
tensor-tang 已提交
257 258 259 260 261 262 263 264
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh)"
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
265 266 267
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
268 269 270 271 272
  AddAttr<std::string>(
      "mkldnn_data_type",
      "(string, default \"float32\"). Data type of mkldnn kernel")
      .SetDefault("float32")
      .InEnum({"float32", "int8", "bfloat16"});
273 274 275 276 277 278 279 280 281 282 283 284
  AddAttr<float>("Scale_data",
                 "Scale to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
      .SetDefault(1.0f);
  AddAttr<float>("Shift_data",
                 "Shift to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
      .SetDefault(0.0f);
  AddAttr<std::vector<float>>("Scale_weights",
                              "Scale_weights to be used for int8 weights data."
                              "Only used with MKL-DNN INT8.")
      .SetDefault({1.0f});
285 286 287 288
  AddAttr<bool>("force_fp32_output",
                "(bool, default false) Force INT8 kernel output FP32, only "
                "used in MKL-DNN INT8")
      .SetDefault(false);
T
tensor-tang 已提交
289
  AddComment(R"DOC(
T
tensor-tang 已提交
290 291
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.
T
tensor-tang 已提交
292 293 294
)DOC");
}

T
tensor-tang 已提交
295
template <typename T>
T
tensor-tang 已提交
296
class FuisonLSTMKernel : public framework::OpKernel<T> {
T
tensor-tang 已提交
297
 public:
L
Leo Chen 已提交
298 299 300
#define INIT_BASE_DEFINES                               \
  using DeviceContext = phi::CPUContext;                \
  auto* x = ctx.Input<LoDTensor>("X");                  \
301 302 303 304 305
  auto* h0 = ctx.Input<phi::DenseTensor>("H0");         \
  auto* c0 = ctx.Input<phi::DenseTensor>("C0");         \
  auto* wx = ctx.Input<phi::DenseTensor>("WeightX");    \
  auto* wh = ctx.Input<phi::DenseTensor>("WeightH");    \
  auto* bias = ctx.Input<phi::DenseTensor>("Bias");     \
L
Leo Chen 已提交
306 307 308 309 310 311 312 313 314
  auto* xx = ctx.Output<LoDTensor>("XX");               \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   \
  auto* cell_out = ctx.Output<LoDTensor>("Cell");       \
  bool is_reverse = ctx.Attr<bool>("is_reverse");       \
  bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
  auto x_dims = x->dims();   /* T x M*/                 \
  auto wh_dims = wh->dims(); /* D x 4D*/                \
  const int M = x_dims[1];                              \
  const int D = wh_dims[0];                             \
315 316
  const int D4 = wh_dims[1]

317 318 319 320 321 322 323 324 325 326 327
#define INIT_OTHER_DEFINES                                                     \
  const T* x_data = x->data<T>();                                              \
  const T* wx_data = wx->data<T>();                                            \
  const T* wh_data = wh->data<T>();                                            \
  /* diagonal weight*/                                                         \
  const T* wp_data = bias->data<T>() + D4;                                     \
  /* for peephole only*/                                                       \
  T* checked_cell_data = nullptr;                                              \
  auto place = ctx.GetPlace();                                                 \
  if (use_peepholes) {                                                         \
    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                           \
328
    auto* checked_cell = ctx.Output<phi::DenseTensor>("CheckedCell");          \
329 330 331
    checked_cell_data = checked_cell->mutable_data<T>(place);                  \
  }                                                                            \
  const jit::lstm_attr_t attr(                                                 \
332 333
      D,                                                                       \
      jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),            \
334 335 336 337 338 339 340 341 342 343 344 345
      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")),            \
      use_peepholes);                                                          \
  jit::lstm_t one_step;                                                        \
  one_step.wp = wp_data;                                                       \
  one_step.checked = checked_cell_data;                                        \
  auto ComputeC1H1 =                                                           \
      jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
          attr);                                                               \
  auto ComputeCtHt =                                                           \
      jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
          attr)
346 347

// Wh GEMM
348 349 350 351 352 353 354 355 356 357 358 359 360 361
#define GEMM_WH_ADDON(bs, prev, out) \
  blas.GEMM(CblasNoTrans,            \
            CblasNoTrans,            \
            bs,                      \
            D4,                      \
            D,                       \
            static_cast<T>(1),       \
            prev,                    \
            D,                       \
            wh_data,                 \
            D4,                      \
            static_cast<T>(1),       \
            out,                     \
            D4)
T
tensor-tang 已提交
362

T
tensor-tang 已提交
363
  void SeqCompute(const framework::ExecutionContext& ctx) const {
364 365
    INIT_BASE_DEFINES;
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
366
    auto x_lod = x->lod();
T
tensor-tang 已提交
367
    const int total_T = x_dims[0];
T
tensor-tang 已提交
368
    const int N = x_lod[0].size() - 1;
T
tensor-tang 已提交
369 370
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    const T* c0_data = c0 ? c0->data<T>() : nullptr;
T
tensor-tang 已提交
371
    T* xx_data = xx->mutable_data<T>(place);
T
tensor-tang 已提交
372 373
    T* h_out_data = hidden_out->mutable_data<T>(place);
    T* c_out_data = cell_out->mutable_data<T>(place);
374
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
375 376

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
377
    phi::funcs::FCFunctor<DeviceContext, T> fc;
378
    fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
B
Brian Liu 已提交
379

T
tensor-tang 已提交
380 381 382 383 384
    int xx_offset = D4;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 4;
T
tensor-tang 已提交
385 386
      h_out_data = h_out_data + offset;
      c_out_data = c_out_data + offset;
T
tensor-tang 已提交
387 388 389 390
      xx_offset = -D4;
      gate_offset = -D;
    }

391 392 393 394 395 396 397 398 399 400
    for (int i = 0; i < N; ++i) {
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
      const T* prev_c_data = nullptr;
      const T* prev_h_data = nullptr;
      int tstart = 0;
      if (h0_data) {
        prev_h_data = h0_data + bid * D;
        prev_c_data = c0_data + bid * D;
      } else {
401 402 403
        one_step.gates = xx_data;
        one_step.ct = c_out_data;
        one_step.ht = h_out_data;
404
        ComputeC1H1(&one_step, &attr);
405 406 407 408 409 410 411
        tstart = 1;
        // move one step
        prev_h_data = h_out_data;
        prev_c_data = c_out_data;
        xx_data = xx_data + xx_offset;
        h_out_data = h_out_data + gate_offset;
        c_out_data = c_out_data + gate_offset;
T
tensor-tang 已提交
412
      }
413 414
      for (int step = tstart; step < seq_len; ++step) {
        GEMM_WH_ADDON(1, prev_h_data, xx_data);
415 416 417 418 419

        one_step.gates = xx_data;
        one_step.ct_1 = prev_c_data;
        one_step.ct = c_out_data;
        one_step.ht = h_out_data;
420
        ComputeCtHt(&one_step, &attr);
421 422 423 424 425 426
        // move one step
        prev_h_data = h_out_data;
        prev_c_data = c_out_data;
        xx_data = xx_data + xx_offset;
        h_out_data = h_out_data + gate_offset;
        c_out_data = c_out_data + gate_offset;
T
tensor-tang 已提交
427
      }
T
tensor-tang 已提交
428
    }
T
tensor-tang 已提交
429 430 431
  }

  void BatchCompute(const framework::ExecutionContext& ctx) const {
432
    INIT_BASE_DEFINES;
T
tensor-tang 已提交
433
    if (x->lod()[0].size() == 2) {
434
      xx->Resize({x_dims[0], D4});
T
tensor-tang 已提交
435
      SeqCompute(ctx);
T
tensor-tang 已提交
436
      return;
T
tensor-tang 已提交
437
    }
438
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
439

440 441
    auto* reordered_h0 = ctx.Output<phi::DenseTensor>("ReorderedH0");
    auto* reordered_c0 = ctx.Output<phi::DenseTensor>("ReorderedC0");
T
tensor-tang 已提交
442 443 444 445 446 447 448 449 450
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
    auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
    T* xx_data = xx->mutable_data<T>(place);
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
    T* batched_h_out_data = batched_h_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
    cell_out->mutable_data<T>(place);
T
tensor-tang 已提交
451

F
Feiyu Chan 已提交
452
    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
T
tensor-tang 已提交
453
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
454
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
455
    phi::funcs::FCFunctor<DeviceContext, T> fc;
T
tensor-tang 已提交
456
    if (M > D4) {
457
      fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>());
T
tensor-tang 已提交
458
      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
T
tensor-tang 已提交
459 460
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
T
tensor-tang 已提交
461
      batched_input->set_lod(xx->lod());
462 463 464 465 466 467 468
      fc(dev_ctx,
         x_dims[0],
         D4,
         M,
         xx_data,
         wx_data,
         batched_input_data,
469
         bias->data<T>());
T
tensor-tang 已提交
470 471
    }

T
tensor-tang 已提交
472 473 474 475 476 477 478
    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});
    reordered_c0->Resize({max_bs, D});

    int tstart = 0;
T
tensor-tang 已提交
479 480
    T* prev_h_data = nullptr;
    T* prev_c_data = nullptr;
T
tensor-tang 已提交
481 482 483 484 485 486
    if (h0) {
      // reorder h0, c0
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
      T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
      const T* h0_data = h0->data<T>();
      const T* c0_data = c0->data<T>();
T
tensor-tang 已提交
487 488
      prev_h_data = reordered_h0_data;
      prev_c_data = reordered_c0_data;
489
      size_t sz = D;
T
tensor-tang 已提交
490
      for (int i = 0; i < max_bs; ++i) {
491 492
        blas.VCOPY(sz, h0_data + seq_order[i] * D, reordered_h0_data);
        blas.VCOPY(sz, c0_data + seq_order[i] * D, reordered_c0_data);
T
tensor-tang 已提交
493 494 495 496
        reordered_h0_data += D;
        reordered_c0_data += D;
      }
    } else {
T
tensor-tang 已提交
497 498 499 500 501
      // compute without h0, c0
      T* cur_in_data = batched_input_data;
      T* cur_h_out_data = batched_h_out_data;
      T* cur_c_out_data = batched_c_out_data;
      for (int i = 0; i < max_bs; ++i) {
502 503 504
        one_step.gates = cur_in_data;
        one_step.ct = cur_c_out_data;
        one_step.ht = cur_h_out_data;
505
        ComputeC1H1(&one_step, &attr);
506

T
tensor-tang 已提交
507 508 509 510 511
        cur_in_data += D4;
        cur_c_out_data += D;
        cur_h_out_data += D;
      }
      tstart = 1;
T
tensor-tang 已提交
512 513
      prev_h_data = batched_h_out_data;
      prev_c_data = batched_c_out_data;
T
tensor-tang 已提交
514
    }
515 516

    // compute kernel part
T
tensor-tang 已提交
517 518
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
T
tensor-tang 已提交
519 520 521 522
    const int offset = tstart * max_bs * D;
    batched_input_data = batched_input_data + offset * 4;
    batched_h_out_data = batched_h_out_data + offset;
    batched_c_out_data = batched_c_out_data + offset;
523 524 525 526 527 528 529 530
    for (int step = tstart; step < max_seq_len; ++step) {
      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
      GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
      T* cur_in_data = batched_input_data;
      T* cur_prev_c_data = prev_c_data;
      T* cur_c_out_data = batched_c_out_data;
      T* cur_h_out_data = batched_h_out_data;
      for (int i = 0; i < cur_bs; ++i) {
531 532 533 534
        one_step.gates = cur_in_data;
        one_step.ct_1 = cur_prev_c_data;
        one_step.ct = cur_c_out_data;
        one_step.ht = cur_h_out_data;
T
tensor-tang 已提交
535
        ComputeCtHt(&one_step, &attr);
536

537 538 539 540 541
        // move one batch
        cur_in_data += D4;
        cur_prev_c_data += D;
        cur_c_out_data += D;
        cur_h_out_data += D;
T
tensor-tang 已提交
542
      }
543 544 545 546 547 548
      // move one step
      prev_c_data = batched_c_out_data;
      prev_h_data = batched_h_out_data;
      batched_c_out_data = cur_c_out_data;
      batched_h_out_data = cur_h_out_data;
      batched_input_data = cur_in_data;
T
tensor-tang 已提交
549 550
    }

F
Feiyu Chan 已提交
551
    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
T
tensor-tang 已提交
552 553 554 555
    batched_h_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_h_out, hidden_out);
    batched_c_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_c_out, cell_out);
T
tensor-tang 已提交
556
  }
T
tensor-tang 已提交
557

T
tensor-tang 已提交
558
  void Compute(const framework::ExecutionContext& ctx) const override {
T
tensor-tang 已提交
559
    if (ctx.Attr<bool>("use_seq")) {
T
tensor-tang 已提交
560 561 562 563 564
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }
T
tensor-tang 已提交
565 566

#undef GEMM_WH_ADDON
567 568
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
T
tensor-tang 已提交
569 570 571 572 573 574
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
575
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker);
T
tensor-tang 已提交
576

577 578
REGISTER_OP_CPU_KERNEL(fusion_lstm,
                       ops::FuisonLSTMKernel<float>,
T
tensor-tang 已提交
579
                       ops::FuisonLSTMKernel<double>);