/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_gru_op.h"

#include <algorithm>  // for std::min
#include <cstring>    // for memcpy
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/sequence2batch.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

// Validates fusion_gru inputs and infers the shapes of its outputs.
//
// With D = WeightX.dims[1] / 3 (the frame / hidden size):
//   X       : [T, M], or [T, 1, M] which is flattened to [T, M]
//   WeightX : [M, 3D]
//   WeightH : [D, 3D]
//   H0      : optional, rank 2 with second dimension D
//   Bias    : optional, [1, 3D]
// Outputs:
//   Hidden  : [T, D], LoD shared with X
//   XX      : [T, 3D] when use_seq is true, otherwise [T, min(M, 3D)];
//             LoD shared with X
//   BatchedInput ([T, 3D]) and BatchedOut ([T, D]) are only shaped (and
//   required) when use_seq is false.
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru");

  auto x_dims = ctx->GetInputDim("X");
  // A [T, 1, M] input is accepted and treated as the matrix [T, M].
  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
                        ? phi::flatten_to_2d(x_dims, 1)
                        : x_dims;
  PADDLE_ENFORCE_EQ(
      x_mat_dims.size(), 2,
      platform::errors::InvalidArgument("The size of input X dims should be 2, "
                                        "or 3 with second dimension equal to "
                                        "1, but now Input X dim is:[%s] ",
                                        x_dims));

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightX) should be 2, but received "
                        "WeightX dim size is:%d, WeightX dim is:[%s] ",
                        wx_dims.size(), wx_dims));
  PADDLE_ENFORCE_EQ(
      wx_dims[0], x_mat_dims[1],
      platform::errors::InvalidArgument(
          "The first dimension of flattened WeightX "
          "should equal to last dimension of flattened input X, but "
          "received flattened WeightX dimension is:%d, flattened X dimension "
          "is:%d",
          wx_dims[0], x_mat_dims[1]));

  // WeightX packs the update, reset and candidate gates side by side.
  int frame_size = wx_dims[1] / 3;
  auto wh_dims = ctx->GetInputDim("WeightH");

  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightH) should be 2, but received "
                        "WeightH dim size is:%d, WeightH dim is:[%s]",
                        wh_dims.size(), wh_dims));
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    platform::errors::InvalidArgument(
                        "The first dimension of WeightH "
                        "should equal to frame_size, but received WeightH's "
                        "first dimension is: "
                        "%d, frame size is:%d",
                        wh_dims[0], frame_size));
  PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightH) "
                        "should equal to 3 * frame_size, but received WeightH "
                        "is:%d, frame size is:%d",
                        wh_dims[1], frame_size));

  if (ctx->HasInput("H0")) {
    auto h0_dims = ctx->GetInputDim("H0");
    // Check the rank before reading h0_dims[1] so that a lower-rank H0 is
    // reported instead of being indexed past its rank.
    PADDLE_ENFORCE_EQ(h0_dims.size(), 2,
                      platform::errors::InvalidArgument(
                          "The rank of Input(H0) should be 2, but received "
                          "H0 dim size is:%d, H0 dim is:[%s]",
                          h0_dims.size(), h0_dims));
    PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                      platform::errors::InvalidArgument(
                          "The width of H0 must be equal to frame_size, but "
                          "received the width of H0 is:%d, frame size is:%d",
                          h0_dims[1], frame_size));
  }
  if (ctx->HasInput("Bias")) {
    auto b_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(b_dims.size(), 2,
                      platform::errors::InvalidArgument(
                          "The rank of Input(Bias) should be 2, but received "
                          "Bias rank is:%d, Bias dim is:[%s]",
                          b_dims.size(), b_dims));
    PADDLE_ENFORCE_EQ(b_dims[0], 1,
                      platform::errors::InvalidArgument(
                          "The first dimension of Input(Bias) should be 1, but "
                          "received Bias first dim is:%d, Bias dim is:[%s]",
                          b_dims[0], b_dims));
    PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
                      platform::errors::InvalidArgument(
                          "The shape of Bias must be [1, frame_size * 3], but "
                          "received bias dim is:[%s], frame size is:%d",
                          b_dims, frame_size));
  }
  framework::DDim out_dims({x_mat_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->ShareLoD("X", "Hidden");

  int64_t xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    // Seq mode stores X * WeightX in XX.
    xx_width = wx_dims[1];
  } else {
    // Batch mode stores in XX whichever is narrower: X * WeightX (width 3D)
    // or the batched copy of X (width M).
    xx_width = std::min(x_mat_dims[1], wx_dims[1]);
    OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
                   "fusion_gru");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
                   "fusion_gru");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut",
                   "fusion_gru");
    ctx->SetOutputDim("BatchedInput", {x_mat_dims[0], wx_dims[1]});
    ctx->SetOutputDim("BatchedOut", out_dims);
  }
  ctx->SetOutputDim("XX", {x_mat_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

// Selects the kernel for fusion_gru. The data type follows Input(X). When
// built with MKLDNN and CanMKLDNNBeUsed() accepts this context, the MKLDNN
// library and layout are chosen. Otherwise the plain library with any
// layout is used.
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, data_type)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

void FusionGRUOpMaker::Make() {
T
tensor-tang 已提交
147 148
  AddInput("X",
           "(LoDTensor) the input is a LodTensor, which support "
T
tensor-tang 已提交
149
           "variable-time length input sequence. The underlying tensor in "
T
tensor-tang 已提交
150 151
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
T
tensor-tang 已提交
152 153 154 155 156
  AddInput("H0",
           "(Tensor, optional) The initial hidden state is an optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size, D is the hidden size.")
      .AsDispensable();
T
tensor-tang 已提交
157 158 159 160
  AddInput("WeightX",
           "(Tensor) The FC weight with shape (M x 3D),"
           "where M is the dim size of x, D is the hidden size. ");
  AddInput("WeightH",
T
tensor-tang 已提交
161 162 163 164 165
           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. "
           "This weight is not exactly D x 3D as: {W_update, W_reset, W_state}"
           "Acutally they are D x 2D and D x D two part weights."
           "{W_update, W_reset; W_state}"
           "{D x (D + D); D x D}");
T
tensor-tang 已提交
166
  AddInput("Bias",
T
tensor-tang 已提交
167 168 169
           "(Tensor, optional) (1 x 3D)."
           "Almost same as GRUOp."
           "Note: if have FC bias it should be added on this bias.")
T
tensor-tang 已提交
170
      .AsDispensable();
T
tensor-tang 已提交
171 172
  AddOutput("ReorderedH0", "(Tensor) (N x D), which N is the min-batch size.")
      .AsIntermediate();
T
tensor-tang 已提交
173
  AddOutput("XX",
T
tensor-tang 已提交
174
            "(LoDTensor) the result after X * WeightX (size is T x 3D)"
T
tensor-tang 已提交
175 176 177
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size, M is the dim size of x input.")
T
tensor-tang 已提交
178
      .AsIntermediate();
T
tensor-tang 已提交
179 180 181 182
  AddOutput("BatchedInput",
            "(LoDTensor) This is the batched result of input X"
            "or the batched result after fc, shape (T x 3D)")
      .AsIntermediate();
T
tensor-tang 已提交
183
  AddOutput("BatchedOut", "(LoDTensor) (T X D) save batched hidden.")
T
tensor-tang 已提交
184
      .AsIntermediate();
T
tensor-tang 已提交
185
  AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
T
tensor-tang 已提交
186 187 188 189 190 191 192 193 194 195
  AddAttr<std::string>("activation",
                       "(string, default tanh) "
                       "The activation type used for output candidate {h}_t.")
      .SetDefault("tanh");
  AddAttr<std::string>(
      "gate_activation",
      "(string, default sigmoid) "
      "The activation type used in update gate and reset gate.")
      .SetDefault("sigmoid");
  AddAttr<bool>("is_reverse",
翟飞跃 已提交
196
                "(bool, default: False) "
T
tensor-tang 已提交
197 198
                "whether to compute reversed GRU.")
      .SetDefault(false);
T
tensor-tang 已提交
199
  AddAttr<bool>("use_seq",
翟飞跃 已提交
200
                "(bool, default: True) "
T
tensor-tang 已提交
201 202
                "whether to use seq mode to compute GRU.")
      .SetDefault(true);
A
Adam 已提交
203 204 205 206
  AddAttr<bool>("origin_mode",
                "bool"
                "use origin mode in article https://arxiv.org/abs/1412.3555")
      .SetDefault(false);
A
Adam 已提交
207 208 209
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
A
Adam 已提交
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
  AddAttr<std::string>(
      "mkldnn_data_type",
      "(string, default \"float32\"). Data type of mkldnn kernel")
      .SetDefault("float32")
      .InEnum({"float32", "int8", "bfloat16"});
  AddAttr<float>("Scale_data",
                 "Scale to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
      .SetDefault(1.0f);
  AddAttr<float>("Shift_data",
                 "Shift to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
      .SetDefault(0.0f);
  AddAttr<std::vector<float>>("Scale_weights",
                              "Scale_weights to be used for int8 weights data."
                              "Only used with MKL-DNN INT8.")
      .SetDefault({1.0f});
  AddAttr<bool>("force_fp32_output",
                "(bool, default false) Force INT8 kernel output FP32, only "
                "used in MKL-DNN INT8")
      .SetDefault(false);
T
tensor-tang 已提交
231 232 233 234 235 236 237
  AddComment(R"DOC(
The Fusion complete GRU Operator.
This operator fuse the fully-connected operator into GRU, 
more details can refer to GRU op.
)DOC");
}

T
tensor-tang 已提交
238
template <typename T>
T
tensor-tang 已提交
239 240
class FusionGRUKernel : public framework::OpKernel<T> {
 public:
T
tensor-tang 已提交
241
  void Compute(const framework::ExecutionContext& ctx) const override {
T
tensor-tang 已提交
242
    if (ctx.Attr<bool>("use_seq")) {
T
tensor-tang 已提交
243 244 245 246 247 248
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }

249 250 251 252 253 254 255
#define INIT_BASE_DEFINES                                  \
  auto* x = ctx.Input<LoDTensor>("X");                     \
  auto* wh = ctx.Input<Tensor>("WeightH");                 \
  auto* xx = ctx.Output<LoDTensor>("XX");                  \
  auto x_lod = x->lod();                                   \
  auto x_dims = x->dims(); /* T x M*/                      \
  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) \
256
                        ? phi::flatten_to_2d(x_dims, 1)    \
257 258 259
                        : x_dims;                          \
  auto wh_dims = wh->dims(); /* D x 3D*/                   \
  const int total_T = x_mat_dims[0];                       \
T
tensor-tang 已提交
260 261
  const int D3 = wh_dims[1]

262 263 264 265 266 267
#define INIT_OTHER_DEFINES                                                   \
  auto* h0 = ctx.Input<Tensor>("H0");                                        \
  auto* wx = ctx.Input<Tensor>("WeightX");                                   \
  auto* bias = ctx.Input<Tensor>("Bias");                                    \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                        \
  bool is_reverse = ctx.Attr<bool>("is_reverse");                            \
268
  const int M = x_mat_dims[1];                                               \
269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287
  const int D = wh_dims[0];                                                  \
  const int D2 = D * 2;                                                      \
  const jit::gru_attr_t attr(                                                \
      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("activation")));              \
  jit::gru_t one_step;                                                       \
  auto ComputeH1 =                                                           \
      jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At(  \
          attr);                                                             \
  auto ComputeHtPart1 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  auto ComputeHtPart2 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  const T* x_data = x->data<T>();                                            \
  const T* wx_data = wx->data<T>();                                          \
  const T* wh_data = wh->data<T>();                                          \
  auto place = ctx.GetPlace();                                               \
T
tensor-tang 已提交
288
  T* xx_data = xx->mutable_data<T>(place)
T
tensor-tang 已提交
289

T
tensor-tang 已提交
290 291
  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
T
tensor-tang 已提交
292 293
    INIT_BASE_DEFINES;
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
294
    const int N = x_lod[0].size() - 1;
T
tensor-tang 已提交
295
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
T
tensor-tang 已提交
296
    const T* wh_state_data = wh_data + D * D2;
T
tensor-tang 已提交
297
    T* hidden_out_data = hidden_out->mutable_data<T>(place);
298
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
299 300 301 302 303

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::FCFunctor<DeviceContext, T> fc;
    fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
       bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320

    int xx_offset = D3;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 3;
      hidden_out_data = hidden_out_data + offset;
      xx_offset = -D3;
      gate_offset = -D;
    }
    auto move_step = [&]() {
      xx_data = xx_data + xx_offset;
      hidden_out_data = hidden_out_data + gate_offset;
    };
    for (int i = 0; i < N; ++i) {
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
T
tensor-tang 已提交
321
      const T* prev_hidden_data = nullptr;
T
tensor-tang 已提交
322 323 324 325
      int tstart = 0;
      if (h0_data) {
        prev_hidden_data = h0_data + bid * D;
      } else {
326 327
        one_step.gates = xx_data;
        one_step.ht = hidden_out_data;
328
        ComputeH1(&one_step, &attr);
T
tensor-tang 已提交
329 330 331 332 333 334 335 336 337
        prev_hidden_data = hidden_out_data;
        tstart = 1;
        move_step();
      }
      for (int step = tstart; step < seq_len; ++step) {
        // gemm prev * (Wu + Wr)
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
                  prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
                  D3);
338 339 340
        one_step.gates = xx_data;
        one_step.ht_1 = prev_hidden_data;
        one_step.ht = hidden_out_data;
341
        ComputeHtPart1(&one_step, &attr);
T
tensor-tang 已提交
342 343 344 345
        // gemm rt * Ws
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                  hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                  xx_data + D2, D3);
346
        ComputeHtPart2(&one_step, &attr);
T
tensor-tang 已提交
347 348 349 350 351 352 353 354
        // save prev
        prev_hidden_data = hidden_out_data;
        move_step();
      }
    }
  }

  void BatchCompute(const framework::ExecutionContext& ctx) const {
T
tensor-tang 已提交
355
    using DeviceContext = paddle::platform::CPUDeviceContext;
T
tensor-tang 已提交
356 357
    INIT_BASE_DEFINES;
    if (x_lod[0].size() == 2) {
358
      xx->Resize({total_T, D3});
T
tensor-tang 已提交
359 360 361
      SeqCompute(ctx);
      return;
    }
T
tensor-tang 已提交
362
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
363 364 365
    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
T
tensor-tang 已提交
366 367 368
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_out_data = batched_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
T
tensor-tang 已提交
369
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
370
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
F
Feiyu Chan 已提交
371
    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
372 373

    math::FCFunctor<DeviceContext, T> fc;
T
tensor-tang 已提交
374
    if (M > D3) {
375 376
      fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
         bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
377
      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
T
tensor-tang 已提交
378 379
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
T
tensor-tang 已提交
380
      batched_input->set_lod(xx->lod());
381 382
      fc(dev_ctx, total_T, D3, M, xx_data, wx_data, batched_input_data,
         bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
383 384
    }

T
tensor-tang 已提交
385 386 387 388
    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});
T
tensor-tang 已提交
389

T
tensor-tang 已提交
390
    int tstart = 0;
T
tensor-tang 已提交
391
    T* prev_hidden_data = nullptr;
T
tensor-tang 已提交
392
    if (h0) {
T
tensor-tang 已提交
393
      // reorder h0
T
tensor-tang 已提交
394
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
T
tensor-tang 已提交
395 396 397 398 399 400 401
      const T* h0_data = h0->data<T>();
      prev_hidden_data = reordered_h0_data;
      size_t sz = sizeof(T) * D;
      for (int i = 0; i < max_bs; ++i) {
        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
        reordered_h0_data += D;
      }
T
tensor-tang 已提交
402
    } else {
T
tensor-tang 已提交
403 404 405 406 407
      // compute without h0
      T* cur_in_data = batched_input_data;
      T* cur_out_data = batched_out_data;
      // W: {W_update, W_reset; W_state}
      for (int i = 0; i < max_bs; ++i) {
408 409
        one_step.gates = cur_in_data;
        one_step.ht = cur_out_data;
410
        ComputeH1(&one_step, &attr);
T
tensor-tang 已提交
411 412 413 414 415 416
        // add offset
        cur_in_data += D3;
        cur_out_data += D;
      }
      tstart = 1;
      prev_hidden_data = batched_out_data;
T
tensor-tang 已提交
417
    }
T
tensor-tang 已提交
418 419 420 421 422 423 424 425 426 427 428 429 430 431
    // Then start from next
    const T* wh_state_data = wh_data + D * D2;
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
    batched_input_data = batched_input_data + tstart * max_bs * D3;
    batched_out_data = batched_out_data + tstart * max_bs * D;
    for (int step = tstart; step < max_seq_len; ++step) {
      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
      // gemm prev * (Wu + Wr)
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D2, D, static_cast<T>(1),
                prev_hidden_data, D, wh_data, D2, static_cast<T>(1),
                batched_input_data, D3);

      T* cur_batched_data = batched_input_data;
432
      T* cur_out_data = batched_out_data;
T
tensor-tang 已提交
433 434
      T* cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
435 436 437
        one_step.gates = cur_batched_data;
        one_step.ht_1 = cur_prev_hidden_data;
        one_step.ht = cur_out_data;
438
        ComputeHtPart1(&one_step, &attr);
439

T
tensor-tang 已提交
440 441
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
442
        cur_out_data += D;
T
tensor-tang 已提交
443 444
      }

T
tensor-tang 已提交
445
      cur_batched_data = batched_input_data;
446
      cur_out_data = batched_out_data;
T
tensor-tang 已提交
447
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1),
448
                cur_out_data, D, wh_state_data, D, static_cast<T>(1),
T
tensor-tang 已提交
449 450 451 452
                cur_batched_data + D2, D3);

      cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
453 454 455
        one_step.gates = cur_batched_data;
        one_step.ht_1 = cur_prev_hidden_data;
        one_step.ht = cur_out_data;
456
        ComputeHtPart2(&one_step, &attr);
T
tensor-tang 已提交
457 458 459
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
        cur_out_data += D;
T
tensor-tang 已提交
460
      }
T
tensor-tang 已提交
461 462 463
      prev_hidden_data = batched_out_data;
      batched_out_data = cur_out_data;
      batched_input_data = cur_batched_data;
T
tensor-tang 已提交
464
    }
T
tensor-tang 已提交
465

F
Feiyu Chan 已提交
466
    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
T
tensor-tang 已提交
467 468
    batched_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_out, hidden_out);
T
tensor-tang 已提交
469
  }
T
tensor-tang 已提交
470 471
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
T
tensor-tang 已提交
472 473 474 475 476 477
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
478 479
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);

T
tensor-tang 已提交
480 481
REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
                       ops::FusionGRUKernel<double>);
482 483 484 485 486 487 488 489 490

/* ==========================  register checkpoint ===========================*/
REGISTER_OP_VERSION(fusion_gru)
    .AddCheckpoint(
        R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "Scale_weights",
            "The added attribute 'Scale_weights' is not yet "
            "registered.",
491
            std::vector<float>{1.0f}));