fusion_gru_op.cc 20.7 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
T
tensor-tang 已提交
16
#include <cstring>  // for memcpy
T
tensor-tang 已提交
17
#include <string>
H
huangxu96 已提交
18
#include <vector>
19
#include "paddle/fluid/framework/op_version_registry.h"
20
#include "paddle/fluid/operators/jit/kernels.h"
T
tensor-tang 已提交
21
#include "paddle/fluid/operators/math/blas.h"
22
#include "paddle/fluid/operators/math/fc.h"
T
tensor-tang 已提交
23
#include "paddle/fluid/operators/math/sequence2batch.h"
A
Adam 已提交
24 25 26
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
T
tensor-tang 已提交
27 28 29 30 31

namespace paddle {
namespace operators {

// Validates the input shapes of fusion_gru and infers its output shapes.
//
// Expected inputs (D = frame_size, the hidden size):
//   X:       (T x M), or (T x 1 x M) which is flattened to (T x M)
//   WeightX: (M x 3D)
//   WeightH: (D x 3D)
//   H0:      optional, its second dimension must be D
//   Bias:    optional, (1 x 3D)
// Outputs:
//   Hidden:  (T x D), shares X's LoD
//   XX:      (T x 3D) when use_seq is true, otherwise (T x min(M, 3D));
//            shares X's LoD
//   BatchedInput (T x 3D) and BatchedOut (T x D) are only shaped when
//   use_seq is false.
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru");

  // A rank-3 X whose middle dimension is 1 is treated as a matrix.
  auto x_dims = ctx->GetInputDim("X");
  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
                        ? framework::flatten_to_2d(x_dims, 1)
                        : x_dims;
  PADDLE_ENFORCE_EQ(
      x_mat_dims.size(), 2,
      platform::errors::InvalidArgument("The size of input X dims should be 2, "
                                        "or 3 with second dimension equal to "
                                        "1, but now Input X dim is:[%s] ",
                                        x_dims));

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightX) should be 2, but received "
                        "WeightX dim size is:%d, WeightX dim is:[%s] ",
                        wx_dims.size(), wx_dims));
  PADDLE_ENFORCE_EQ(
      wx_dims[0], x_mat_dims[1],
      platform::errors::InvalidArgument(
          "The first dimension of flattened WeightX "
          "should equal to last dimension of flattened input X, but "
          "received flattened WeightX dimension is:%d, flattened X dimension "
          "is:%d",
          wx_dims[0], x_mat_dims[1]));

  int frame_size = wx_dims[1] / 3;
  // The division above truncates. A WeightX width that is not a multiple of 3
  // would otherwise pass all checks below. The kernel's FC step then walks
  // WeightX with a row stride of 3 * frame_size (the WeightH width), so its
  // rows would be read misaligned.
  PADDLE_ENFORCE_EQ(wx_dims[1], 3 * frame_size,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightX) should be "
                        "divisible by 3, but received WeightX dim is:[%s]",
                        wx_dims));

  auto wh_dims = ctx->GetInputDim("WeightH");
  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightH) should be 2, but received "
                        "WeightH dim size is:%d, WeightH dim is:[%s]",
                        wh_dims.size(), wh_dims));
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    platform::errors::InvalidArgument(
                        "The first dimension of WeightH "
                        "should equal to frame_size, but received WeightH's "
                        "first dimension is: "
                        "%d, frame size is:%d",
                        wh_dims[0], frame_size));
  PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightH) "
                        "should equal to 3 * frame_size, but received WeightH "
                        "is:%d, frame size is:%d",
                        wh_dims[1], frame_size));

  if (ctx->HasInput("H0")) {
    auto h0_dims = ctx->GetInputDim("H0");
    PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                      platform::errors::InvalidArgument(
                          "The width of H0 must be equal to frame_size, but "
                          "received the width of H0 is:%d, frame size is:%d",
                          h0_dims[1], frame_size));
  }
  if (ctx->HasInput("Bias")) {
    auto b_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(b_dims.size(), 2,
                      platform::errors::InvalidArgument(
                          "The rank of Input(Bias) should be 2, but received "
                          "Bias rank is:%d, Bias dim is:[%s]",
                          b_dims.size(), b_dims));
    PADDLE_ENFORCE_EQ(b_dims[0], 1,
                      platform::errors::InvalidArgument(
                          "The first dimension of Input(Bias) should be 1, but "
                          "received Bias first dim is:%d, Bias dim is:[%s]",
                          b_dims[0], b_dims));
    PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
                      platform::errors::InvalidArgument(
                          "The shape of Bias must be [1, frame_size * 3], but "
                          "received bias dim is:[%s], frame size is:%d",
                          b_dims, frame_size));
  }
  framework::DDim out_dims({x_mat_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->ShareLoD("X", "Hidden");
  int xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    xx_width = wx_dims[1];
  } else {
    // Batch mode stores either the FC result or the batched X in XX,
    // whichever is narrower.
    xx_width = x_mat_dims[1] > wx_dims[1] ? wx_dims[1] : x_mat_dims[1];
    OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
                   "fusion_gru");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
                   "fusion_gru");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut",
                   "fusion_gru");
    ctx->SetOutputDim("BatchedInput", {x_mat_dims[0], wx_dims[1]});
    ctx->SetOutputDim("BatchedOut", out_dims);
  }
  ctx->SetOutputDim("XX", {x_mat_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

// Picks the kernel for fusion_gru. The data type always follows X.
// When the build has MKL-DNN and CanMKLDNNBeUsed(ctx) holds, the MKL-DNN
// library and layout are selected. Otherwise the plain kernel with any layout
// is used.
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx)) {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
        framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(
      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
      framework::DataLayout::kAnyLayout, framework::LibraryType::kPlain);
}

void FusionGRUOpMaker::Make() {
T
tensor-tang 已提交
148 149
  AddInput("X",
           "(LoDTensor) the input is a LodTensor, which support "
T
tensor-tang 已提交
150
           "variable-time length input sequence. The underlying tensor in "
T
tensor-tang 已提交
151 152
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
T
tensor-tang 已提交
153 154 155 156 157
  AddInput("H0",
           "(Tensor, optional) The initial hidden state is an optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size, D is the hidden size.")
      .AsDispensable();
T
tensor-tang 已提交
158 159 160 161
  AddInput("WeightX",
           "(Tensor) The FC weight with shape (M x 3D),"
           "where M is the dim size of x, D is the hidden size. ");
  AddInput("WeightH",
T
tensor-tang 已提交
162 163 164 165 166
           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. "
           "This weight is not exactly D x 3D as: {W_update, W_reset, W_state}"
           "Acutally they are D x 2D and D x D two part weights."
           "{W_update, W_reset; W_state}"
           "{D x (D + D); D x D}");
T
tensor-tang 已提交
167
  AddInput("Bias",
T
tensor-tang 已提交
168 169 170
           "(Tensor, optional) (1 x 3D)."
           "Almost same as GRUOp."
           "Note: if have FC bias it should be added on this bias.")
T
tensor-tang 已提交
171
      .AsDispensable();
T
tensor-tang 已提交
172 173
  AddOutput("ReorderedH0", "(Tensor) (N x D), which N is the min-batch size.")
      .AsIntermediate();
T
tensor-tang 已提交
174
  AddOutput("XX",
T
tensor-tang 已提交
175
            "(LoDTensor) the result after X * WeightX (size is T x 3D)"
T
tensor-tang 已提交
176 177 178
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size, M is the dim size of x input.")
T
tensor-tang 已提交
179
      .AsIntermediate();
T
tensor-tang 已提交
180 181 182 183
  AddOutput("BatchedInput",
            "(LoDTensor) This is the batched result of input X"
            "or the batched result after fc, shape (T x 3D)")
      .AsIntermediate();
T
tensor-tang 已提交
184
  AddOutput("BatchedOut", "(LoDTensor) (T X D) save batched hidden.")
T
tensor-tang 已提交
185
      .AsIntermediate();
T
tensor-tang 已提交
186
  AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
T
tensor-tang 已提交
187 188 189 190 191 192 193 194 195 196
  AddAttr<std::string>("activation",
                       "(string, default tanh) "
                       "The activation type used for output candidate {h}_t.")
      .SetDefault("tanh");
  AddAttr<std::string>(
      "gate_activation",
      "(string, default sigmoid) "
      "The activation type used in update gate and reset gate.")
      .SetDefault("sigmoid");
  AddAttr<bool>("is_reverse",
翟飞跃 已提交
197
                "(bool, default: False) "
T
tensor-tang 已提交
198 199
                "whether to compute reversed GRU.")
      .SetDefault(false);
T
tensor-tang 已提交
200
  AddAttr<bool>("use_seq",
翟飞跃 已提交
201
                "(bool, default: True) "
T
tensor-tang 已提交
202 203
                "whether to use seq mode to compute GRU.")
      .SetDefault(true);
A
Adam 已提交
204 205 206 207
  AddAttr<bool>("origin_mode",
                "bool"
                "use origin mode in article https://arxiv.org/abs/1412.3555")
      .SetDefault(false);
A
Adam 已提交
208 209 210
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
A
Adam 已提交
211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
  AddAttr<std::string>(
      "mkldnn_data_type",
      "(string, default \"float32\"). Data type of mkldnn kernel")
      .SetDefault("float32")
      .InEnum({"float32", "int8", "bfloat16"});
  AddAttr<float>("Scale_data",
                 "Scale to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
      .SetDefault(1.0f);
  AddAttr<float>("Shift_data",
                 "Shift to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
      .SetDefault(0.0f);
  AddAttr<std::vector<float>>("Scale_weights",
                              "Scale_weights to be used for int8 weights data."
                              "Only used with MKL-DNN INT8.")
      .SetDefault({1.0f});
  AddAttr<bool>("force_fp32_output",
                "(bool, default false) Force INT8 kernel output FP32, only "
                "used in MKL-DNN INT8")
      .SetDefault(false);
T
tensor-tang 已提交
232 233 234 235 236 237 238
  AddComment(R"DOC(
The Fusion complete GRU Operator.
This operator fuse the fully-connected operator into GRU, 
more details can refer to GRU op.
)DOC");
}

T
tensor-tang 已提交
239
template <typename T>
T
tensor-tang 已提交
240 241
class FusionGRUKernel : public framework::OpKernel<T> {
 public:
T
tensor-tang 已提交
242
  void Compute(const framework::ExecutionContext& ctx) const override {
T
tensor-tang 已提交
243
    if (ctx.Attr<bool>("use_seq")) {
T
tensor-tang 已提交
244 245 246 247 248 249
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }

250 251 252 253 254 255 256 257 258 259 260
#define INIT_BASE_DEFINES                                     \
  auto* x = ctx.Input<LoDTensor>("X");                        \
  auto* wh = ctx.Input<Tensor>("WeightH");                    \
  auto* xx = ctx.Output<LoDTensor>("XX");                     \
  auto x_lod = x->lod();                                      \
  auto x_dims = x->dims(); /* T x M*/                         \
  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)    \
                        ? framework::flatten_to_2d(x_dims, 1) \
                        : x_dims;                             \
  auto wh_dims = wh->dims(); /* D x 3D*/                      \
  const int total_T = x_mat_dims[0];                          \
T
tensor-tang 已提交
261 262
  const int D3 = wh_dims[1]

263 264 265 266 267 268
#define INIT_OTHER_DEFINES                                                   \
  auto* h0 = ctx.Input<Tensor>("H0");                                        \
  auto* wx = ctx.Input<Tensor>("WeightX");                                   \
  auto* bias = ctx.Input<Tensor>("Bias");                                    \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                        \
  bool is_reverse = ctx.Attr<bool>("is_reverse");                            \
269
  const int M = x_mat_dims[1];                                               \
270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
  const int D = wh_dims[0];                                                  \
  const int D2 = D * 2;                                                      \
  const jit::gru_attr_t attr(                                                \
      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("activation")));              \
  jit::gru_t one_step;                                                       \
  auto ComputeH1 =                                                           \
      jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At(  \
          attr);                                                             \
  auto ComputeHtPart1 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  auto ComputeHtPart2 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  const T* x_data = x->data<T>();                                            \
  const T* wx_data = wx->data<T>();                                          \
  const T* wh_data = wh->data<T>();                                          \
  auto place = ctx.GetPlace();                                               \
T
tensor-tang 已提交
289
  T* xx_data = xx->mutable_data<T>(place)
T
tensor-tang 已提交
290

T
tensor-tang 已提交
291 292
  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
T
tensor-tang 已提交
293 294
    INIT_BASE_DEFINES;
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
295
    const int N = x_lod[0].size() - 1;
T
tensor-tang 已提交
296
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
T
tensor-tang 已提交
297
    const T* wh_state_data = wh_data + D * D2;
T
tensor-tang 已提交
298
    T* hidden_out_data = hidden_out->mutable_data<T>(place);
T
tensor-tang 已提交
299
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
300 301 302 303 304

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::FCFunctor<DeviceContext, T> fc;
    fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
       bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321

    int xx_offset = D3;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 3;
      hidden_out_data = hidden_out_data + offset;
      xx_offset = -D3;
      gate_offset = -D;
    }
    auto move_step = [&]() {
      xx_data = xx_data + xx_offset;
      hidden_out_data = hidden_out_data + gate_offset;
    };
    for (int i = 0; i < N; ++i) {
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
T
tensor-tang 已提交
322
      const T* prev_hidden_data = nullptr;
T
tensor-tang 已提交
323 324 325 326
      int tstart = 0;
      if (h0_data) {
        prev_hidden_data = h0_data + bid * D;
      } else {
327 328
        one_step.gates = xx_data;
        one_step.ht = hidden_out_data;
329
        ComputeH1(&one_step, &attr);
T
tensor-tang 已提交
330 331 332 333 334 335 336 337 338
        prev_hidden_data = hidden_out_data;
        tstart = 1;
        move_step();
      }
      for (int step = tstart; step < seq_len; ++step) {
        // gemm prev * (Wu + Wr)
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
                  prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
                  D3);
339 340 341
        one_step.gates = xx_data;
        one_step.ht_1 = prev_hidden_data;
        one_step.ht = hidden_out_data;
342
        ComputeHtPart1(&one_step, &attr);
T
tensor-tang 已提交
343 344 345 346
        // gemm rt * Ws
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                  hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                  xx_data + D2, D3);
347
        ComputeHtPart2(&one_step, &attr);
T
tensor-tang 已提交
348 349 350 351 352 353 354 355
        // save prev
        prev_hidden_data = hidden_out_data;
        move_step();
      }
    }
  }

  void BatchCompute(const framework::ExecutionContext& ctx) const {
T
tensor-tang 已提交
356
    using DeviceContext = paddle::platform::CPUDeviceContext;
T
tensor-tang 已提交
357 358
    INIT_BASE_DEFINES;
    if (x_lod[0].size() == 2) {
359
      xx->Resize({total_T, D3});
T
tensor-tang 已提交
360 361 362
      SeqCompute(ctx);
      return;
    }
T
tensor-tang 已提交
363
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
364 365 366
    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
T
tensor-tang 已提交
367 368 369
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_out_data = batched_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
T
tensor-tang 已提交
370 371 372
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
373 374

    math::FCFunctor<DeviceContext, T> fc;
T
tensor-tang 已提交
375
    if (M > D3) {
376 377
      fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
         bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
378
      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
T
tensor-tang 已提交
379 380
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
T
tensor-tang 已提交
381
      batched_input->set_lod(xx->lod());
382 383
      fc(dev_ctx, total_T, D3, M, xx_data, wx_data, batched_input_data,
         bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
384 385
    }

T
tensor-tang 已提交
386 387 388 389
    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});
T
tensor-tang 已提交
390

T
tensor-tang 已提交
391
    int tstart = 0;
T
tensor-tang 已提交
392
    T* prev_hidden_data = nullptr;
T
tensor-tang 已提交
393
    if (h0) {
T
tensor-tang 已提交
394
      // reorder h0
T
tensor-tang 已提交
395
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
T
tensor-tang 已提交
396 397 398 399 400 401 402
      const T* h0_data = h0->data<T>();
      prev_hidden_data = reordered_h0_data;
      size_t sz = sizeof(T) * D;
      for (int i = 0; i < max_bs; ++i) {
        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
        reordered_h0_data += D;
      }
T
tensor-tang 已提交
403
    } else {
T
tensor-tang 已提交
404 405 406 407 408
      // compute without h0
      T* cur_in_data = batched_input_data;
      T* cur_out_data = batched_out_data;
      // W: {W_update, W_reset; W_state}
      for (int i = 0; i < max_bs; ++i) {
409 410
        one_step.gates = cur_in_data;
        one_step.ht = cur_out_data;
411
        ComputeH1(&one_step, &attr);
T
tensor-tang 已提交
412 413 414 415 416 417
        // add offset
        cur_in_data += D3;
        cur_out_data += D;
      }
      tstart = 1;
      prev_hidden_data = batched_out_data;
T
tensor-tang 已提交
418
    }
T
tensor-tang 已提交
419 420 421 422 423 424 425 426 427 428 429 430 431 432
    // Then start from next
    const T* wh_state_data = wh_data + D * D2;
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
    batched_input_data = batched_input_data + tstart * max_bs * D3;
    batched_out_data = batched_out_data + tstart * max_bs * D;
    for (int step = tstart; step < max_seq_len; ++step) {
      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
      // gemm prev * (Wu + Wr)
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D2, D, static_cast<T>(1),
                prev_hidden_data, D, wh_data, D2, static_cast<T>(1),
                batched_input_data, D3);

      T* cur_batched_data = batched_input_data;
433
      T* cur_out_data = batched_out_data;
T
tensor-tang 已提交
434 435
      T* cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
436 437 438
        one_step.gates = cur_batched_data;
        one_step.ht_1 = cur_prev_hidden_data;
        one_step.ht = cur_out_data;
439
        ComputeHtPart1(&one_step, &attr);
440

T
tensor-tang 已提交
441 442
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
443
        cur_out_data += D;
T
tensor-tang 已提交
444 445
      }

T
tensor-tang 已提交
446
      cur_batched_data = batched_input_data;
447
      cur_out_data = batched_out_data;
T
tensor-tang 已提交
448
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1),
449
                cur_out_data, D, wh_state_data, D, static_cast<T>(1),
T
tensor-tang 已提交
450 451 452 453
                cur_batched_data + D2, D3);

      cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
454 455 456
        one_step.gates = cur_batched_data;
        one_step.ht_1 = cur_prev_hidden_data;
        one_step.ht = cur_out_data;
457
        ComputeHtPart2(&one_step, &attr);
T
tensor-tang 已提交
458 459 460
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
        cur_out_data += D;
T
tensor-tang 已提交
461
      }
T
tensor-tang 已提交
462 463 464
      prev_hidden_data = batched_out_data;
      batched_out_data = cur_out_data;
      batched_input_data = cur_batched_data;
T
tensor-tang 已提交
465
    }
T
tensor-tang 已提交
466

T
tensor-tang 已提交
467
    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
T
tensor-tang 已提交
468 469
    batched_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_out, hidden_out);
T
tensor-tang 已提交
470
  }
T
tensor-tang 已提交
471 472
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
T
tensor-tang 已提交
473 474 475 476 477 478
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
479 480
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);

T
tensor-tang 已提交
481 482
REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
                       ops::FusionGRUKernel<double>);
483 484 485 486 487 488 489 490 491 492

/* ==========================  register checkpoint ===========================*/
REGISTER_OP_VERSION(fusion_gru)
    .AddCheckpoint(
        R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "Scale_weights",
            "The added attribute 'Scale_weights' is not yet "
            "registered.",
            {1.0f}));