fusion_lstm_op.cc 24.1 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/fused/fusion_lstm_op.h"
16

T
tensor-tang 已提交
17
#include <string>
18

19
#include "paddle/fluid/operators/jit/kernels.h"
20
#include "paddle/phi/kernels/funcs/blas/blas.h"
21
#include "paddle/phi/kernels/funcs/fc_functor.h"
F
Feiyu Chan 已提交
22
#include "paddle/phi/kernels/funcs/sequence2batch.h"
23 24 25
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
T
tensor-tang 已提交
26

T
tensor-tang 已提交
27 28 29 30
namespace paddle {
namespace operators {

// Validates the inputs of fusion_lstm and infers the shapes of its outputs.
//
// Shape conventions (T = total time steps, M = width of X, D = frame size):
//   X: (T x M), WeightX: (M x 4D), WeightH: (D x 4D),
//   Bias: (1 x 4D), or (1 x 7D) when `use_peepholes` is set.
// Hidden and Cell are (T x D) and share X's LoD. XX is (T x 4D) when
// `use_seq` is set; otherwise it is (T x min(M, 4D)) and the Batched*
// intermediates are shaped as well.
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_lstm");
  OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "fusion_lstm");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "Input(X)'s rank must be 2, but received x's rank "
                        "is:%d, x dim is:[%s]",
                        x_dims.size(),
                        x_dims));

  // H0 is optional; when it is given, C0 must be given too and both must
  // have the same shape.
  if (ctx->HasInput("H0")) {
    OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "fusion_lstm");
    auto h_dims = ctx->GetInputDim("H0");
    auto c_dims = ctx->GetInputDim("C0");
    PADDLE_ENFORCE_EQ(h_dims,
                      c_dims,
                      platform::errors::InvalidArgument(
                          "The dimension of Input(H0) and Input(C0) should be "
                          "same, but received h0 dims is:[%s], c0 dims is:[%s]",
                          h_dims,
                          c_dims));
  }

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightX) should be 2, but received "
                        "WeightX's rank is:%d, WeightX dim is:[%s]",
                        wx_dims.size(),
                        wx_dims));
  PADDLE_ENFORCE_EQ(wx_dims[0],
                    x_dims[1],
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(WeightX) "
                        "should equal to second dimension of Input(X), but "
                        "received WeightX first dim is:%d, X second dim is:%d",
                        wx_dims[0],
                        x_dims[1]));
  // WeightX holds the four gate blocks side by side. A width that is not a
  // multiple of 4 would otherwise be silently truncated by the division
  // below, while the kernel uses WeightH's width (4 * frame size) as the row
  // width of WeightX.
  PADDLE_ENFORCE_EQ(
      wx_dims[1] % 4,
      0,
      platform::errors::InvalidArgument(
          "The second dimension of Input(WeightX) should be a multiple of 4 "
          "(4 * frame size), but received WeightX second dim is:%d",
          wx_dims[1]));

  int frame_size = wx_dims[1] / 4;
  auto wh_dims = ctx->GetInputDim("WeightH");

  PADDLE_ENFORCE_EQ(wh_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightH) should be 2, but received "
                        "WeightH rank is:%d, WeightH dim is:[%s]",
                        wh_dims.size(),
                        wh_dims));
  PADDLE_ENFORCE_EQ(wh_dims[0],
                    frame_size,
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(WeightH) "
                        "should equal to frame size, but received WeightH "
                        "first dim is:%d, frame size is:%d.",
                        wh_dims[0],
                        frame_size));

  PADDLE_ENFORCE_EQ(wh_dims[1],
                    4 * frame_size,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightH) "
                        "should equal to 4 * frame_size, but received WeightH "
                        "second dimension is:%d, frame size is:%d.",
                        wh_dims[1],
                        frame_size));

  auto b_dims = ctx->GetInputDim("Bias");
  PADDLE_ENFORCE_EQ(b_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(Bias) should be 2, but received "
                        "Bias rank is:%d, Bias dim is:[%s]",
                        b_dims.size(),
                        b_dims));
  PADDLE_ENFORCE_EQ(b_dims[0],
                    1,
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(Bias) should be 1, but "
                        "received Bias's dimension is:[%s]",
                        b_dims));

  // With peepholes, Bias additionally carries the three diagonal weights
  // (4D gate biases + 3D peephole weights = 7D).
  if (ctx->Attrs().Get<bool>("use_peepholes")) {
    PADDLE_ENFORCE_EQ(b_dims[1],
                      7 * frame_size,
                      platform::errors::InvalidArgument(
                          "The second dimension of Input(Bias) should be "
                          "7 * %d if enable peepholes connection, but received "
                          "Bias dim is:[%s]",
                          frame_size,
                          b_dims));
    ctx->SetOutputDim("CheckedCell", {2, frame_size});
  } else {
    PADDLE_ENFORCE_EQ(
        b_dims[1],
        4 * frame_size,
        platform::errors::InvalidArgument(
            "The second dimension of Input(Bias) should be "
            "4 * %d if disable peepholes, but received Bias dim is:[%s]",
            frame_size,
            b_dims));
  }

  framework::DDim out_dims({x_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");
  int xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    // Seq mode: XX stores X * WeightX (+ bias) for every time step.
    xx_width = wx_dims[1];
  } else {
    // Batch mode: XX stores whichever of X and X * WeightX is narrower.
    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];

    OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"),
                   "Output",
                   "BatchedInput",
                   "fusion_lstm");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedHidden"),
                   "Output",
                   "BatchedHidden",
                   "fusion_lstm");
    OP_INOUT_CHECK(
        ctx->HasOutput("BatchedCell"), "Output", "BatchedCell", "fusion_lstm");
    OP_INOUT_CHECK(
        ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0", "fusion_lstm");
    OP_INOUT_CHECK(
        ctx->HasOutput("ReorderedC0"), "Output", "ReorderedC0", "fusion_lstm");

    ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
    ctx->SetOutputDim("BatchedHidden", out_dims);
    ctx->SetOutputDim("BatchedCell", out_dims);
  }
  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

// Picks the kernel key from the data type of Input(X). In builds with
// PADDLE_WITH_MKLDNN, the oneDNN layout/library pair is requested whenever
// CanMKLDNNBeUsed() accepts this op; otherwise the plain kernel on the
// current place is selected.
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  const auto input_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  const bool use_onednn = this->CanMKLDNNBeUsed(ctx, input_type);
  if (use_onednn) {
    return framework::OpKernelType(input_type,
                                   ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  // Default path: same data type, current place, default layout/library.
  return framework::OpKernelType(input_type, ctx.GetPlace());
}

void FusionLSTMOpMaker::Make() {
T
tensor-tang 已提交
191
  AddInput("X",
T
tensor-tang 已提交
192
           "(LoDTensor) the input is a LodTensor, which support "
T
tensor-tang 已提交
193
           "variable-time length input sequence. The underlying tensor in "
T
tensor-tang 已提交
194 195 196 197 198 199 200 201 202
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
  AddInput("WeightX",
           "(Tensor) the learnable weights of X."
           " - The shape is (M x 4D), where M is the dim size of x, D is the "
           "hidden size. "
           " - Weight = {W_cx, W_ix, W_fx, W_ox}");
  AddInput("WeightH",
           "(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
T
tensor-tang 已提交
203 204 205
           " - The shape is (D x 4D), where D is the hidden size. "
           " - Weight = {W_ch, W_ih, W_fh, W_oh}");
  AddInput("Bias",
T
tensor-tang 已提交
206 207
           "(Tensor) the learnable weights. Almost same as LSTMOp"
           "Note: we should add the fc bias into this (1x4D) in bias."
T
tensor-tang 已提交
208 209 210 211 212 213 214 215
           "input-hidden bias weight and peephole connections weight if "
           "setting `use_peepholes` True. "
           "1. `use_peepholes = False` "
           " - The shape is (1 x 4D). "
           " - Bias = {b_c, b_i, b_f, b_o}."
           "2. `use_peepholes = True` "
           " - The shape is (1 x 7D). "
           " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
T
tensor-tang 已提交
216 217 218 219 220 221 222 223 224 225 226 227
  AddInput("H0",
           "(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size and D is the hidden size.")
      .AsDispensable();
  AddInput("C0",
           "(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size. `H0` and `C0` can be NULL but only at the same time.")
      .AsDispensable();
T
tensor-tang 已提交
228
  AddOutput("Hidden",
T
tensor-tang 已提交
229
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
T
tensor-tang 已提交
230 231
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
T
tensor-tang 已提交
232
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
T
tensor-tang 已提交
233
            "The shape is (T x D), and lod is the same with the `Input`.");
T
tensor-tang 已提交
234
  AddOutput("XX",
T
tensor-tang 已提交
235 236 237
            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
T
tensor-tang 已提交
238 239
            " D is the hidden size, M is the dim size of x input.")
      .AsIntermediate();
T
tensor-tang 已提交
240 241 242 243 244
  AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate();
  AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
  AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
T
tensor-tang 已提交
245 246
  AddOutput("CheckedCell", "(Tensor) (2 x D) only for peephole.")
      .AsIntermediate();
T
tensor-tang 已提交
247
  AddAttr<bool>("use_peepholes",
翟飞跃 已提交
248
                "(bool, default: True) "
T
tensor-tang 已提交
249 250 251
                "whether to enable diagonal/peephole connections.")
      .SetDefault(true);
  AddAttr<bool>("is_reverse",
翟飞跃 已提交
252
                "(bool, default: False) "
T
tensor-tang 已提交
253 254
                "whether to compute reversed LSTM.")
      .SetDefault(false);
T
tensor-tang 已提交
255
  AddAttr<bool>("use_seq",
翟飞跃 已提交
256
                "(bool, default: True) "
T
tensor-tang 已提交
257 258
                "whether to use seq mode to compute.")
      .SetDefault(true);
T
tensor-tang 已提交
259 260 261 262 263 264 265 266
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid)"
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh)"
翟飞跃 已提交
267
                       "The activation for cell output, `tanh` by default.")
T
tensor-tang 已提交
268 269 270 271 272 273 274 275
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh)"
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
276 277 278
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
279 280 281 282 283
  AddAttr<std::string>(
      "mkldnn_data_type",
      "(string, default \"float32\"). Data type of mkldnn kernel")
      .SetDefault("float32")
      .InEnum({"float32", "int8", "bfloat16"});
284 285 286 287 288 289 290 291 292 293 294 295
  AddAttr<float>("Scale_data",
                 "Scale to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
      .SetDefault(1.0f);
  AddAttr<float>("Shift_data",
                 "Shift to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
      .SetDefault(0.0f);
  AddAttr<std::vector<float>>("Scale_weights",
                              "Scale_weights to be used for int8 weights data."
                              "Only used with MKL-DNN INT8.")
      .SetDefault({1.0f});
296 297 298 299
  AddAttr<bool>("force_fp32_output",
                "(bool, default false) Force INT8 kernel output FP32, only "
                "used in MKL-DNN INT8")
      .SetDefault(false);
T
tensor-tang 已提交
300
  AddComment(R"DOC(
T
tensor-tang 已提交
301 302
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.
T
tensor-tang 已提交
303 304 305
)DOC");
}

T
tensor-tang 已提交
306
template <typename T>
T
tensor-tang 已提交
307
class FuisonLSTMKernel : public framework::OpKernel<T> {
T
tensor-tang 已提交
308
 public:
L
Leo Chen 已提交
309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
#define INIT_BASE_DEFINES                               \
  using DeviceContext = phi::CPUContext;                \
  auto* x = ctx.Input<LoDTensor>("X");                  \
  auto* h0 = ctx.Input<Tensor>("H0");                   \
  auto* c0 = ctx.Input<Tensor>("C0");                   \
  auto* wx = ctx.Input<Tensor>("WeightX");              \
  auto* wh = ctx.Input<Tensor>("WeightH");              \
  auto* bias = ctx.Input<Tensor>("Bias");               \
  auto* xx = ctx.Output<LoDTensor>("XX");               \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   \
  auto* cell_out = ctx.Output<LoDTensor>("Cell");       \
  bool is_reverse = ctx.Attr<bool>("is_reverse");       \
  bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
  auto x_dims = x->dims();   /* T x M*/                 \
  auto wh_dims = wh->dims(); /* D x 4D*/                \
  const int M = x_dims[1];                              \
  const int D = wh_dims[0];                             \
326 327
  const int D4 = wh_dims[1]

328 329 330 331 332 333 334 335 336 337 338 339 340 341 342
#define INIT_OTHER_DEFINES                                                     \
  const T* x_data = x->data<T>();                                              \
  const T* wx_data = wx->data<T>();                                            \
  const T* wh_data = wh->data<T>();                                            \
  /* diagonal weight*/                                                         \
  const T* wp_data = bias->data<T>() + D4;                                     \
  /* for peephole only*/                                                       \
  T* checked_cell_data = nullptr;                                              \
  auto place = ctx.GetPlace();                                                 \
  if (use_peepholes) {                                                         \
    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                           \
    auto* checked_cell = ctx.Output<Tensor>("CheckedCell");                    \
    checked_cell_data = checked_cell->mutable_data<T>(place);                  \
  }                                                                            \
  const jit::lstm_attr_t attr(                                                 \
343 344
      D,                                                                       \
      jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),            \
345 346 347 348 349 350 351 352 353 354 355 356
      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")),            \
      use_peepholes);                                                          \
  jit::lstm_t one_step;                                                        \
  one_step.wp = wp_data;                                                       \
  one_step.checked = checked_cell_data;                                        \
  auto ComputeC1H1 =                                                           \
      jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
          attr);                                                               \
  auto ComputeCtHt =                                                           \
      jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
          attr)
357 358

// Wh GEMM
359 360 361 362 363 364 365 366 367 368 369 370 371 372
#define GEMM_WH_ADDON(bs, prev, out) \
  blas.GEMM(CblasNoTrans,            \
            CblasNoTrans,            \
            bs,                      \
            D4,                      \
            D,                       \
            static_cast<T>(1),       \
            prev,                    \
            D,                       \
            wh_data,                 \
            D4,                      \
            static_cast<T>(1),       \
            out,                     \
            D4)
T
tensor-tang 已提交
373

T
tensor-tang 已提交
374
  void SeqCompute(const framework::ExecutionContext& ctx) const {
375 376
    INIT_BASE_DEFINES;
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
377
    auto x_lod = x->lod();
T
tensor-tang 已提交
378
    const int total_T = x_dims[0];
T
tensor-tang 已提交
379
    const int N = x_lod[0].size() - 1;
T
tensor-tang 已提交
380 381
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    const T* c0_data = c0 ? c0->data<T>() : nullptr;
T
tensor-tang 已提交
382
    T* xx_data = xx->mutable_data<T>(place);
T
tensor-tang 已提交
383 384
    T* h_out_data = hidden_out->mutable_data<T>(place);
    T* c_out_data = cell_out->mutable_data<T>(place);
385
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
386 387

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
388
    phi::funcs::FCFunctor<DeviceContext, T> fc;
389
    fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
B
Brian Liu 已提交
390

T
tensor-tang 已提交
391 392 393 394 395
    int xx_offset = D4;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 4;
T
tensor-tang 已提交
396 397
      h_out_data = h_out_data + offset;
      c_out_data = c_out_data + offset;
T
tensor-tang 已提交
398 399 400 401
      xx_offset = -D4;
      gate_offset = -D;
    }

402 403 404 405 406 407 408 409 410 411
    for (int i = 0; i < N; ++i) {
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
      const T* prev_c_data = nullptr;
      const T* prev_h_data = nullptr;
      int tstart = 0;
      if (h0_data) {
        prev_h_data = h0_data + bid * D;
        prev_c_data = c0_data + bid * D;
      } else {
412 413 414
        one_step.gates = xx_data;
        one_step.ct = c_out_data;
        one_step.ht = h_out_data;
415
        ComputeC1H1(&one_step, &attr);
416 417 418 419 420 421 422
        tstart = 1;
        // move one step
        prev_h_data = h_out_data;
        prev_c_data = c_out_data;
        xx_data = xx_data + xx_offset;
        h_out_data = h_out_data + gate_offset;
        c_out_data = c_out_data + gate_offset;
T
tensor-tang 已提交
423
      }
424 425
      for (int step = tstart; step < seq_len; ++step) {
        GEMM_WH_ADDON(1, prev_h_data, xx_data);
426 427 428 429 430

        one_step.gates = xx_data;
        one_step.ct_1 = prev_c_data;
        one_step.ct = c_out_data;
        one_step.ht = h_out_data;
431
        ComputeCtHt(&one_step, &attr);
432 433 434 435 436 437
        // move one step
        prev_h_data = h_out_data;
        prev_c_data = c_out_data;
        xx_data = xx_data + xx_offset;
        h_out_data = h_out_data + gate_offset;
        c_out_data = c_out_data + gate_offset;
T
tensor-tang 已提交
438
      }
T
tensor-tang 已提交
439
    }
T
tensor-tang 已提交
440 441 442
  }

  void BatchCompute(const framework::ExecutionContext& ctx) const {
443
    INIT_BASE_DEFINES;
T
tensor-tang 已提交
444
    if (x->lod()[0].size() == 2) {
445
      xx->Resize({x_dims[0], D4});
T
tensor-tang 已提交
446
      SeqCompute(ctx);
T
tensor-tang 已提交
447
      return;
T
tensor-tang 已提交
448
    }
449
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
450

T
tensor-tang 已提交
451 452 453 454 455 456 457 458 459 460 461
    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
    auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
    auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
    T* xx_data = xx->mutable_data<T>(place);
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
    T* batched_h_out_data = batched_h_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
    cell_out->mutable_data<T>(place);
T
tensor-tang 已提交
462

F
Feiyu Chan 已提交
463
    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
T
tensor-tang 已提交
464
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
465
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
466
    phi::funcs::FCFunctor<DeviceContext, T> fc;
T
tensor-tang 已提交
467
    if (M > D4) {
468
      fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>());
T
tensor-tang 已提交
469
      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
T
tensor-tang 已提交
470 471
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
T
tensor-tang 已提交
472
      batched_input->set_lod(xx->lod());
473 474 475 476 477 478 479
      fc(dev_ctx,
         x_dims[0],
         D4,
         M,
         xx_data,
         wx_data,
         batched_input_data,
480
         bias->data<T>());
T
tensor-tang 已提交
481 482
    }

T
tensor-tang 已提交
483 484 485 486 487 488 489
    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});
    reordered_c0->Resize({max_bs, D});

    int tstart = 0;
T
tensor-tang 已提交
490 491
    T* prev_h_data = nullptr;
    T* prev_c_data = nullptr;
T
tensor-tang 已提交
492 493 494 495 496 497
    if (h0) {
      // reorder h0, c0
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
      T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
      const T* h0_data = h0->data<T>();
      const T* c0_data = c0->data<T>();
T
tensor-tang 已提交
498 499
      prev_h_data = reordered_h0_data;
      prev_c_data = reordered_c0_data;
500
      size_t sz = D;
T
tensor-tang 已提交
501
      for (int i = 0; i < max_bs; ++i) {
502 503
        blas.VCOPY(sz, h0_data + seq_order[i] * D, reordered_h0_data);
        blas.VCOPY(sz, c0_data + seq_order[i] * D, reordered_c0_data);
T
tensor-tang 已提交
504 505 506 507
        reordered_h0_data += D;
        reordered_c0_data += D;
      }
    } else {
T
tensor-tang 已提交
508 509 510 511 512
      // compute without h0, c0
      T* cur_in_data = batched_input_data;
      T* cur_h_out_data = batched_h_out_data;
      T* cur_c_out_data = batched_c_out_data;
      for (int i = 0; i < max_bs; ++i) {
513 514 515
        one_step.gates = cur_in_data;
        one_step.ct = cur_c_out_data;
        one_step.ht = cur_h_out_data;
516
        ComputeC1H1(&one_step, &attr);
517

T
tensor-tang 已提交
518 519 520 521 522
        cur_in_data += D4;
        cur_c_out_data += D;
        cur_h_out_data += D;
      }
      tstart = 1;
T
tensor-tang 已提交
523 524
      prev_h_data = batched_h_out_data;
      prev_c_data = batched_c_out_data;
T
tensor-tang 已提交
525
    }
526 527

    // compute kernel part
T
tensor-tang 已提交
528 529
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
T
tensor-tang 已提交
530 531 532 533
    const int offset = tstart * max_bs * D;
    batched_input_data = batched_input_data + offset * 4;
    batched_h_out_data = batched_h_out_data + offset;
    batched_c_out_data = batched_c_out_data + offset;
534 535 536 537 538 539 540 541
    for (int step = tstart; step < max_seq_len; ++step) {
      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
      GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
      T* cur_in_data = batched_input_data;
      T* cur_prev_c_data = prev_c_data;
      T* cur_c_out_data = batched_c_out_data;
      T* cur_h_out_data = batched_h_out_data;
      for (int i = 0; i < cur_bs; ++i) {
542 543 544 545
        one_step.gates = cur_in_data;
        one_step.ct_1 = cur_prev_c_data;
        one_step.ct = cur_c_out_data;
        one_step.ht = cur_h_out_data;
T
tensor-tang 已提交
546
        ComputeCtHt(&one_step, &attr);
547

548 549 550 551 552
        // move one batch
        cur_in_data += D4;
        cur_prev_c_data += D;
        cur_c_out_data += D;
        cur_h_out_data += D;
T
tensor-tang 已提交
553
      }
554 555 556 557 558 559
      // move one step
      prev_c_data = batched_c_out_data;
      prev_h_data = batched_h_out_data;
      batched_c_out_data = cur_c_out_data;
      batched_h_out_data = cur_h_out_data;
      batched_input_data = cur_in_data;
T
tensor-tang 已提交
560 561
    }

F
Feiyu Chan 已提交
562
    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
T
tensor-tang 已提交
563 564 565 566
    batched_h_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_h_out, hidden_out);
    batched_c_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_c_out, cell_out);
T
tensor-tang 已提交
567
  }
T
tensor-tang 已提交
568

T
tensor-tang 已提交
569
  void Compute(const framework::ExecutionContext& ctx) const override {
T
tensor-tang 已提交
570
    if (ctx.Attr<bool>("use_seq")) {
T
tensor-tang 已提交
571 572 573 574 575
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }
T
tensor-tang 已提交
576 577

#undef GEMM_WH_ADDON
578 579
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
T
tensor-tang 已提交
580 581 582 583 584 585
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
586
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker);
T
tensor-tang 已提交
587

588 589
REGISTER_OP_CPU_KERNEL(fusion_lstm,
                       ops::FuisonLSTMKernel<float>,
T
tensor-tang 已提交
590
                       ops::FuisonLSTMKernel<double>);