/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {

// Validate all inputs/outputs of the fused LSTM op and infer output dims.
// X is (T x M); WeightX is (M x 4D); WeightH is (D x 4D); Bias is
// (1 x 4D) or (1 x 7D) with peepholes. Hidden/Cell come out as (T x D).
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  // Mandatory inputs; H0/C0 are optional and validated together below.
  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
                 "Input(WeightX) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
                 "Input(WeightH) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Bias"),
                 "Input(Bias) of LSTM should not be null.");

  // All outputs, including the intermediate batched/reordered buffers,
  // must be declared so the kernel can allocate them.
  PADDLE_ENFORCE(ctx->HasOutput("XX"),
                 "Output(XX) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                 "Output(Hidden) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                 "Output(Cell) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
                 "Output(BatchedInput) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
                 "Output(BatchedHidden) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
                 "Output(BatchedCell) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
                 "Output(ReorderedH0) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
                 "Output(ReorderedC0) of LSTM should not be null.");

  auto x_dims = ctx->GetInputDim("X");  // T x M
  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");

  // H0 and C0 can only be given together, with identical dims.
  if (ctx->HasInput("H0")) {
    PADDLE_ENFORCE(ctx->HasInput("C0"),
                   "Input(Cell) and Input(Hidden) of LSTM should not "
                   "be null at the same time.");
    auto h_dims = ctx->GetInputDim("H0");
    auto c_dims = ctx->GetInputDim("C0");
    PADDLE_ENFORCE(h_dims == c_dims,
                   "The dimension of Input(H0) and Input(C0) "
                   "should be the same.");
  }

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
                    "The rank of Input(WeightX) should be 2.");
  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
                    "The first dimension of Input(WeightX) "
                    "should be %d.",
                    x_dims[1]);

  // The four gates are packed along WeightX's second dim, so D = 4D / 4.
  int frame_size = wx_dims[1] / 4;
  auto wh_dims = ctx->GetInputDim("WeightH");
  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    "The rank of Input(WeightH) should be 2.");
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    "The first dimension of Input(WeightH) "
                    "should be %d.",
                    frame_size);
  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
                    "The second dimension of Input(WeightH) "
                    "should be 4 * %d.",
                    frame_size);

  auto b_dims = ctx->GetInputDim("Bias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
  PADDLE_ENFORCE_EQ(b_dims[0], 1,
                    "The first dimension of Input(Bias) should be 1.");

  // With peepholes the bias additionally carries W_ic, W_fc, W_oc
  // (three more D-sized vectors), giving 7D instead of 4D.
  auto use_peepholes = ctx->Attrs().Get<bool>("use_peepholes");
  PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size,
                    "The second dimension of Input(Bias) should be "
                    "7 * %d if enable peepholes connection or"
                    "4 * %d if disable peepholes",
                    frame_size, frame_size);

  framework::DDim out_dims({x_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
  ctx->SetOutputDim("BatchedHidden", out_dims);
  ctx->SetOutputDim("BatchedCell", out_dims);
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");

  // XX holds either x*WeightX (T x 4D, seq mode) or, in batch mode,
  // whichever of (T x M) / (T x 4D) is narrower; the kernel reuses it.
  int xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    xx_width = wx_dims[1];
  } else {
    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
  }
  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

// Pick the kernel by the data type of Input(X) on the current device context.
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
      ctx.device_context());
}

void FusionLSTMOpMaker::Make() {
T
tensor-tang 已提交
126
  AddInput("X",
T
tensor-tang 已提交
127
           "(LoDTensor) the input is a LodTensor, which support "
T
tensor-tang 已提交
128
           "variable-time length input sequence. The underlying tensor in "
T
tensor-tang 已提交
129 130 131 132 133 134 135 136 137
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
  AddInput("WeightX",
           "(Tensor) the learnable weights of X."
           " - The shape is (M x 4D), where M is the dim size of x, D is the "
           "hidden size. "
           " - Weight = {W_cx, W_ix, W_fx, W_ox}");
  AddInput("WeightH",
           "(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
T
tensor-tang 已提交
138 139 140
           " - The shape is (D x 4D), where D is the hidden size. "
           " - Weight = {W_ch, W_ih, W_fh, W_oh}");
  AddInput("Bias",
T
tensor-tang 已提交
141 142
           "(Tensor) the learnable weights. Almost same as LSTMOp"
           "Note: we should add the fc bias into this (1x4D) in bias."
T
tensor-tang 已提交
143 144 145 146 147 148 149 150
           "input-hidden bias weight and peephole connections weight if "
           "setting `use_peepholes` True. "
           "1. `use_peepholes = False` "
           " - The shape is (1 x 4D). "
           " - Bias = {b_c, b_i, b_f, b_o}."
           "2. `use_peepholes = True` "
           " - The shape is (1 x 7D). "
           " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
T
tensor-tang 已提交
151 152 153 154 155 156 157 158 159 160 161 162
  AddInput("H0",
           "(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size and D is the hidden size.")
      .AsDispensable();
  AddInput("C0",
           "(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size. `H0` and `C0` can be NULL but only at the same time.")
      .AsDispensable();
T
tensor-tang 已提交
163
  AddOutput("Hidden",
T
tensor-tang 已提交
164
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
T
tensor-tang 已提交
165 166
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
T
tensor-tang 已提交
167
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
T
tensor-tang 已提交
168
            "The shape is (T x D), and lod is the same with the `Input`.");
T
tensor-tang 已提交
169
  AddOutput("XX",
T
tensor-tang 已提交
170 171 172
            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
T
tensor-tang 已提交
173 174
            " D is the hidden size, M is the dim size of x input.")
      .AsIntermediate();
T
tensor-tang 已提交
175 176 177 178 179
  AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate();
  AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
  AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
T
tensor-tang 已提交
180 181 182 183 184 185 186 187
  AddAttr<bool>("use_peepholes",
                "(bool, defalut: True) "
                "whether to enable diagonal/peephole connections.")
      .SetDefault(true);
  AddAttr<bool>("is_reverse",
                "(bool, defalut: False) "
                "whether to compute reversed LSTM.")
      .SetDefault(false);
T
tensor-tang 已提交
188 189 190 191
  AddAttr<bool>("use_seq",
                "(bool, defalut: True) "
                "whether to use seq mode to compute.")
      .SetDefault(true);
T
tensor-tang 已提交
192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid)"
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh)"
                       "The activation for cell output, `tanh` by defalut.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh)"
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddComment(R"DOC(
T
tensor-tang 已提交
210 211
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.
T
tensor-tang 已提交
212 213 214
)DOC");
}

template <typename T>
T
tensor-tang 已提交
216
class FuisonLSTMKernel : public framework::OpKernel<T> {
T
tensor-tang 已提交
217
 public:
T
tensor-tang 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234
#define INIT_VEC_FUNC                                                          \
  std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
  auto& act_gate_str = ctx.Attr<std::string>("gate_activation");               \
  auto& act_cell_str = ctx.Attr<std::string>("cell_activation");               \
  auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");          \
  if (platform::jit::MayIUse(platform::jit::avx)) {                            \
    math::VecActivations<T, platform::jit::avx> act_functor;                   \
    act_gate = act_functor(act_gate_str);                                      \
    act_cell = act_functor(act_cell_str);                                      \
    act_cand = act_functor(act_cand_str);                                      \
  } else {                                                                     \
    math::VecActivations<T, platform::jit::isa_any> act_functor;               \
    act_gate = act_functor(act_gate_str);                                      \
    act_cell = act_functor(act_cell_str);                                      \
    act_cand = act_functor(act_cand_str);                                      \
  }

B
Brian Liu 已提交
235 236 237 238 239 240 241 242 243 244 245
#define INIT_BASE_INPUT_OUTPUT                          \
  auto* x = ctx.Input<LoDTensor>("X");                  \
  auto* h0 = ctx.Input<Tensor>("H0");                   \
  auto* c0 = ctx.Input<Tensor>("C0");                   \
  auto* wx = ctx.Input<Tensor>("WeightX");              \
  auto* wh = ctx.Input<Tensor>("WeightH");              \
  auto* bias = ctx.Input<Tensor>("Bias");               \
  auto* xx = ctx.Output<LoDTensor>("XX");               \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   \
  auto* cell_out = ctx.Output<LoDTensor>("Cell");       \
  bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
T
tensor-tang 已提交
246 247 248 249 250 251 252 253 254 255 256
  bool is_reverse = ctx.Attr<bool>("is_reverse");

#define INIT_BASE_SIZES                  \
  auto x_dims = x->dims();   /* T x M*/  \
  auto wh_dims = wh->dims(); /* D x 4D*/ \
  const int M = x_dims[1];               \
  const int D = wh_dims[0];              \
  const int D2 = D * 2;                  \
  const int D3 = D * 3;                  \
  const int D4 = wh_dims[1];

T
tensor-tang 已提交
257 258
  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
T
tensor-tang 已提交
259 260 261
    INIT_BASE_INPUT_OUTPUT
    INIT_BASE_SIZES
    INIT_VEC_FUNC
T
tensor-tang 已提交
262

T
tensor-tang 已提交
263
    auto x_lod = x->lod();
T
tensor-tang 已提交
264
    const int total_T = x_dims[0];
T
tensor-tang 已提交
265
    const int N = x_lod[0].size() - 1;  // batch size
T
tensor-tang 已提交
266 267

    const T* x_data = x->data<T>();
T
tensor-tang 已提交
268 269
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    const T* c0_data = c0 ? c0->data<T>() : nullptr;
B
Brian Liu 已提交
270 271
    const T* bias_data = bias->data<T>();
    const T* wc_data = bias_data + D4;  // w_ic, w_fc, w_oc
T
tensor-tang 已提交
272
    const T* wx_data = wx->data<T>();
T
tensor-tang 已提交
273
    const T* wh_data = wh->data<T>();
B
Brian Liu 已提交
274

T
tensor-tang 已提交
275
    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T
tensor-tang 已提交
276 277
    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
T
tensor-tang 已提交
278

B
Brian Liu 已提交
279 280 281 282 283 284
    // use local variable
    framework::DDim check_dims({3, D});
    Tensor checked_cell;  // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
    auto checked_cell_data =
        checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());

T
tensor-tang 已提交
285
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
T
tensor-tang 已提交
286
    math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
T
tensor-tang 已提交
287
                                      xx_data, bias->data<T>());
T
tensor-tang 已提交
288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303
    int xx_offset = D4;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 4;
      hidden_out_data = hidden_out_data + offset;
      cell_out_data = cell_out_data + offset;
      xx_offset = -D4;
      gate_offset = -D;
    }

    auto move_step = [&]() {
      xx_data = xx_data + xx_offset;
      hidden_out_data = hidden_out_data + gate_offset;
      cell_out_data = cell_out_data + gate_offset;
    };
T
tensor-tang 已提交
304 305

    for (int i = 0; i < N; ++i) {
T
tensor-tang 已提交
306 307
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
T
tensor-tang 已提交
308 309
      const T* prev_c_data = nullptr;
      const T* prev_h_data = nullptr;
B
Brian Liu 已提交
310

T
tensor-tang 已提交
311 312
      int tstart = 0;
      if (h0_data) {
T
tensor-tang 已提交
313 314
        prev_h_data = h0_data + bid * D;
        prev_c_data = c0_data + bid * D;
T
tensor-tang 已提交
315
      } else {
B
Brian Liu 已提交
316 317 318 319
        // If step == 0 and there is no initialized hidden state, that is to say
        // the H0 is zeros. Then W_h * H_t-1 can be skipped

        // ~C_t
T
tensor-tang 已提交
320
        act_cand(D, xx_data, xx_data);
B
Brian Liu 已提交
321 322 323 324 325 326 327 328
        if (use_peepholes) {
          // I_t, F_t
          act_gate(D2, xx_data + D, xx_data + D);
        } else {
          // I_t, F_t, O_t
          act_gate(D3, xx_data + D, xx_data + D);
        }
        // C_t = I_t * ~C_t
T
tensor-tang 已提交
329
        blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
B
Brian Liu 已提交
330 331 332 333 334 335 336 337 338

        if (use_peepholes) {
          // + W_oc * C_t for peephole connection
          blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
          blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
          // O_t
          act_gate(D, xx_data + D3, xx_data + D3);
        }

T
tensor-tang 已提交
339
        // hidden out= act_state(cellout) * outgate
T
tensor-tang 已提交
340
        act_cell(D, cell_out_data, xx_data + D2);
B
Brian Liu 已提交
341
        // H_t = O_t * act_state(C_t)
T
tensor-tang 已提交
342 343 344
        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);

        // prev
T
tensor-tang 已提交
345 346
        prev_h_data = hidden_out_data;
        prev_c_data = cell_out_data;
T
tensor-tang 已提交
347

B
Brian Liu 已提交
348
        tstart = 1;
T
tensor-tang 已提交
349
        move_step();
T
tensor-tang 已提交
350
      }
B
Brian Liu 已提交
351

T
tensor-tang 已提交
352
      for (int step = tstart; step < seq_len; ++step) {
B
Brian Liu 已提交
353
        // + W_h * H_t-1
T
tensor-tang 已提交
354
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
T
tensor-tang 已提交
355
                  prev_h_data, D, wh_data, D4, static_cast<T>(1), xx_data, D4);
T
tensor-tang 已提交
356

B
Brian Liu 已提交
357
        // ~C_t
T
tensor-tang 已提交
358
        act_cand(D, xx_data, xx_data);
T
tensor-tang 已提交
359

B
Brian Liu 已提交
360 361 362 363 364 365 366 367 368 369 370 371 372
        if (use_peepholes) {
          // + W_ic|W_fc * C_t-1 for peephole connection
          blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
          blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
          blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D);
          // I_t, F_t
          act_gate(D2, xx_data + D, xx_data + D);
        } else {
          // I_t, F_t, O_t
          act_gate(D3, xx_data + D, xx_data + D);
        }

        // F_t * C_t-1
T
tensor-tang 已提交
373
        blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2);
B
Brian Liu 已提交
374
        // I_t * ~C_t
T
tensor-tang 已提交
375
        blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
B
Brian Liu 已提交
376
        // C_t = F_t * C_t-1 + I_t * ~C_t
T
tensor-tang 已提交
377 378
        blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);

B
Brian Liu 已提交
379 380 381 382 383 384 385 386
        if (use_peepholes) {
          // + W_oc * C_t for peephole connection
          blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
          blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
          // O_t
          act_gate(D, xx_data + D3, xx_data + D3);
        }

T
tensor-tang 已提交
387
        // hidden out= act_state(cellout) * outgate
T
tensor-tang 已提交
388
        act_cell(D, cell_out_data, xx_data + D2);
B
Brian Liu 已提交
389
        // H_t = O_t * act_state(C_t)
T
tensor-tang 已提交
390 391 392
        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);

        // prev
T
tensor-tang 已提交
393 394
        prev_h_data = hidden_out_data;
        prev_c_data = cell_out_data;
T
tensor-tang 已提交
395

T
tensor-tang 已提交
396
        move_step();
B
Brian Liu 已提交
397 398
      }  // for each step in batch
    }    // for each batch
T
tensor-tang 已提交
399 400 401 402
  }

  void BatchCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = platform::CPUDeviceContext;
T
tensor-tang 已提交
403
    INIT_BASE_INPUT_OUTPUT
B
Brian Liu 已提交
404
    if (x->lod()[0].size() == 2) {  // batch size == 1
T
tensor-tang 已提交
405
      SeqCompute(ctx);
T
tensor-tang 已提交
406
      return;
T
tensor-tang 已提交
407 408 409 410
    }
    INIT_BASE_SIZES
    INIT_VEC_FUNC

T
tensor-tang 已提交
411 412 413 414 415
    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
    auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
    auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
T
tensor-tang 已提交
416

T
tensor-tang 已提交
417 418
    const T* x_data = x->data<T>();
    const T* wx_data = wx->data<T>();
T
tensor-tang 已提交
419
    const T* wh_data = wh->data<T>();
B
Brian Liu 已提交
420 421
    const T* bias_data = bias->data<T>();
    const T* wc_data = bias_data + D4;  // w_ic, w_fc, w_oc
T
tensor-tang 已提交
422 423 424 425 426 427 428
    auto place = ctx.GetPlace();
    T* xx_data = xx->mutable_data<T>(place);
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
    T* batched_h_out_data = batched_h_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
    cell_out->mutable_data<T>(place);
T
tensor-tang 已提交
429

B
Brian Liu 已提交
430 431 432 433 434 435
    // use local variable
    framework::DDim check_dims({3, D});
    Tensor checked_cell;  // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
    auto checked_cell_data =
        checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());

T
tensor-tang 已提交
436
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
T
tensor-tang 已提交
437 438
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
T
tensor-tang 已提交
439 440 441 442
    if (M > D4) {
      math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
                                        xx_data, bias->data<T>());
      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
T
tensor-tang 已提交
443 444
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
T
tensor-tang 已提交
445 446 447
      batched_input->set_lod(xx->lod());
      math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, xx_data,
                                        wx_data, batched_input_data,
448
                                        bias->data<T>());
T
tensor-tang 已提交
449 450
    }

T
tensor-tang 已提交
451 452 453 454 455 456
    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});
    reordered_c0->Resize({max_bs, D});

B
Brian Liu 已提交
457 458 459 460 461 462 463 464 465 466 467 468
    T* prev_batch_h_data = nullptr;
    T* prev_batch_c_data = nullptr;
    T* cur_batch_in_data = batched_input_data;
    T* cur_batch_h_out_data = batched_h_out_data;
    T* cur_batch_c_out_data = batched_c_out_data;

    auto move_step = [&](int bs) {
      cur_batch_in_data += bs * D4;
      cur_batch_c_out_data += bs * D;
      cur_batch_h_out_data += bs * D;
    };

T
tensor-tang 已提交
469 470 471 472 473 474 475
    int tstart = 0;
    if (h0) {
      // reorder h0, c0
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
      T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
      const T* h0_data = h0->data<T>();
      const T* c0_data = c0->data<T>();
B
Brian Liu 已提交
476 477
      prev_batch_h_data = reordered_h0_data;
      prev_batch_c_data = reordered_c0_data;
T
tensor-tang 已提交
478 479 480 481 482 483 484 485
      size_t sz = sizeof(T) * D;
      for (int i = 0; i < max_bs; ++i) {
        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
        std::memcpy(reordered_c0_data, c0_data + seq_order[i] * D, sz);
        reordered_h0_data += D;
        reordered_c0_data += D;
      }
    } else {
B
Brian Liu 已提交
486 487 488 489 490 491 492 493 494 495
      // Compute with no H0/C0
      T* cur_in_data = cur_batch_in_data;
      T* cur_c_out_data = cur_batch_c_out_data;
      T* cur_h_out_data = cur_batch_h_out_data;

      // If step == 0 and there is no initialized hidden state, that is to say
      // the H0 is zeros. Then W_h * H_t-1 can be skiped

      for (int i = 0; i < max_bs; ++i) {  // iterate each data in 1st batch
        // ~C_t
T
tensor-tang 已提交
496
        act_cand(D, cur_in_data, cur_in_data);
B
Brian Liu 已提交
497 498 499 500 501 502 503 504 505 506

        if (use_peepholes) {
          // I_t, F_t
          act_gate(D2, cur_in_data + D, cur_in_data + D);
        } else {
          // I_t, F_t, O_t
          act_gate(D3, cur_in_data + D, cur_in_data + D);
        }

        // C_t = I_t * ~C_t
T
tensor-tang 已提交
507
        blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data);
B
Brian Liu 已提交
508 509 510 511 512 513 514 515 516 517

        if (use_peepholes) {
          // + W_oc * C_t for peephole connection
          blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2);
          blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
                    cur_in_data + D3);
          // O_t
          act_gate(D, cur_in_data + D3, cur_in_data + D3);
        }

T
tensor-tang 已提交
518 519
        // hidden out= act_state(cellout) * outgate
        act_cell(D, cur_c_out_data, cur_in_data + D2);
B
Brian Liu 已提交
520
        // H_t = O_t * act_state(C_t)
T
tensor-tang 已提交
521 522
        blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);

B
Brian Liu 已提交
523
        // move to next data in the same batch
T
tensor-tang 已提交
524 525 526 527
        cur_in_data += D4;
        cur_c_out_data += D;
        cur_h_out_data += D;
      }
B
Brian Liu 已提交
528 529 530 531 532

      // move to data for next timestep
      prev_batch_h_data = cur_batch_h_out_data;
      prev_batch_c_data = cur_batch_c_out_data;
      move_step(max_bs);
T
tensor-tang 已提交
533
      tstart = 1;
T
tensor-tang 已提交
534
    }
B
Brian Liu 已提交
535

T
tensor-tang 已提交
536 537 538 539
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
    for (int step = tstart; step < max_seq_len; ++step) {
      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
B
Brian Liu 已提交
540
      // + W_h * H_t-1
T
tensor-tang 已提交
541
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast<T>(1),
B
Brian Liu 已提交
542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559
                prev_batch_h_data, D, wh_data, D4, static_cast<T>(1),
                cur_batch_in_data, D4);

      T* cur_in_data = cur_batch_in_data;
      T* cur_c_out_data = cur_batch_c_out_data;
      T* cur_h_out_data = cur_batch_h_out_data;
      T* prev_c_data = prev_batch_c_data;  // NULL if no C0 in step0
      T* prev_h_data = prev_batch_h_data;  // NULL if no H0 in step0
      auto next_data_in_batch = [&]() {
        cur_in_data += D4;
        cur_c_out_data += D;
        cur_h_out_data += D;
        prev_c_data = prev_c_data ? prev_c_data + D : nullptr;
        prev_h_data = prev_h_data ? prev_h_data + D : nullptr;
      };

      for (int i = 0; i < cur_bs; ++i) {  // iterate each data in same batch
        // ~C_t
T
tensor-tang 已提交
560
        act_cand(D, cur_in_data, cur_in_data);
B
Brian Liu 已提交
561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576

        if (use_peepholes) {
          // + W_ic|W_fc * C_t-1 for peephole connection
          blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
          blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
          blas.VADD(D2, cur_in_data + D, checked_cell_data, cur_in_data + D);
          // I_t, F_t
          act_gate(D2, cur_in_data + D, cur_in_data + D);
        } else {
          // I_t, F_t, O_t
          act_gate(D3, cur_in_data + D, cur_in_data + D);
        }

        // F_t * C_t-1
        blas.VMUL(D, cur_in_data + D2, prev_c_data, cur_in_data + D2);
        // I_t * ~C_t
T
tensor-tang 已提交
577
        blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
B
Brian Liu 已提交
578
        // C_t = F_t * C_t-1 + I_t * ~C_t
T
tensor-tang 已提交
579
        blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
B
Brian Liu 已提交
580 581 582 583 584 585 586 587 588 589

        if (use_peepholes) {
          // + W_oc * C_t for peephole connection
          blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2);
          blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
                    cur_in_data + D3);
          // O_t
          act_gate(D, cur_in_data + D3, cur_in_data + D3);
        }

T
tensor-tang 已提交
590 591
        // hidden out= act_state(cellout) * outgate
        act_cell(D, cur_c_out_data, cur_in_data + D2);
B
Brian Liu 已提交
592
        // H_t = O_t * act_state(C_t)
T
tensor-tang 已提交
593 594
        blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);

B
Brian Liu 已提交
595 596
        // move to next data in same batch
        next_data_in_batch();
T
tensor-tang 已提交
597
      }
B
Brian Liu 已提交
598 599 600 601
      // move to data for next timestep
      prev_batch_h_data = cur_batch_h_out_data;
      prev_batch_c_data = cur_batch_c_out_data;
      move_step(cur_bs);
T
tensor-tang 已提交
602 603 604
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
T
tensor-tang 已提交
605 606 607 608
    batched_h_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_h_out, hidden_out);
    batched_c_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_c_out, cell_out);
T
tensor-tang 已提交
609
  }
T
tensor-tang 已提交
610

T
tensor-tang 已提交
611
  void Compute(const framework::ExecutionContext& ctx) const override {
T
tensor-tang 已提交
612
    if (ctx.Attr<bool>("use_seq")) {
T
tensor-tang 已提交
613 614 615 616 617
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }
T
tensor-tang 已提交
618 619 620
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
T
tensor-tang 已提交
621 622 623 624 625 626
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
T
tensor-tang 已提交
627
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
T
tensor-tang 已提交
628 629
                  paddle::framework::DefaultGradOpDescMaker<true>);

T
tensor-tang 已提交
630 631
REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
                       ops::FuisonLSTMKernel<double>);