fusion_gru_op.cc 18.7 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
T
tensor-tang 已提交
16
#include <cstring>  // for memcpy
T
tensor-tang 已提交
17
#include <string>
18
#include "paddle/fluid/operators/jit/kernels.h"
T
tensor-tang 已提交
19
#include "paddle/fluid/operators/math/blas.h"
20
#include "paddle/fluid/operators/math/fc.h"
T
tensor-tang 已提交
21
#include "paddle/fluid/operators/math/sequence2batch.h"
A
Adam 已提交
22 23 24
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
T
tensor-tang 已提交
25 26 27 28 29

namespace paddle {
namespace operators {

// Validates the input shapes of fusion_gru and infers its output shapes.
//
// Expected inputs (D is the hidden/frame size, derived from WeightX):
//   X       : rank 2, (T x M)
//   WeightX : rank 2, (M x 3D)  -- its width must be a multiple of 3
//   WeightH : rank 2, (D x 3D)
//   H0      : optional, rank 2, width D
//   Bias    : optional, (1 x 3D)
// Outputs:
//   Hidden  : (T x D), shares X's LoD
//   XX      : (T x 3D) when use_seq is true, otherwise (T x min(M, 3D));
//             shares X's LoD
//   BatchedInput (T x 3D) / BatchedOut (T x D): only set when use_seq is
//             false, together with a presence check on ReorderedH0.
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru");

  OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "Input(X)'s rank must be 2, but received input dim "
                        "size is:%d, input dim is:[%s]",
                        x_dims.size(), x_dims));

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightX) should be 2, but received "
                        "WeightX dim size is:%d, WeightX dim is:[%s] ",
                        wx_dims.size(), wx_dims));
  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
                    platform::errors::InvalidArgument(
                        "The first dimension of Input(WeightX) "
                        "should equal to second dimension of input x, but "
                        "received WeightX dimension is:%d, x dimension is:%d",
                        wx_dims[0], x_dims[1]));
  // frame_size is derived by integer division below; reject widths that would
  // be silently truncated, since the kernel uses WeightH's width (3D) as the
  // row stride of the FC output computed with WeightX.
  PADDLE_ENFORCE_EQ(wx_dims[1] % 3, 0,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightX) should be "
                        "divisible by 3 (3 * frame_size), but received "
                        "WeightX second dimension is:%d",
                        wx_dims[1]));

  int frame_size = wx_dims[1] / 3;
  auto wh_dims = ctx->GetInputDim("WeightH");

  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightH) should be 2, but received "
                        "WeightH dim size is:%d, WeightH dim is:[%s]",
                        wh_dims.size(), wh_dims));
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    platform::errors::InvalidArgument(
                        "The first dimension of WeightH "
                        "should equal to frame_size, but received WeightH's "
                        "first dimension is: "
                        "%d, frame size is:%d",
                        wh_dims[0], frame_size));
  PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightH) "
                        "should equal to 3 * frame_size, but received WeightH "
                        "is:%d, frame size is:%d",
                        wh_dims[1], frame_size));

  if (ctx->HasInput("H0")) {
    auto h0_dims = ctx->GetInputDim("H0");
    // Check the rank before indexing h0_dims[1].
    PADDLE_ENFORCE_EQ(h0_dims.size(), 2,
                      platform::errors::InvalidArgument(
                          "The rank of Input(H0) should be 2, but received "
                          "H0 dim size is:%d, H0 dim is:[%s]",
                          h0_dims.size(), h0_dims));
    PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                      platform::errors::InvalidArgument(
                          "The width of H0 must be equal to frame_size, but "
                          "received the width of H0 is:%d, frame size is:%d",
                          h0_dims[1], frame_size));
  }
  if (ctx->HasInput("Bias")) {
    auto b_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(b_dims.size(), 2,
                      platform::errors::InvalidArgument(
                          "The rank of Input(Bias) should be 2, but received "
                          "Bias rank is:%d, Bias dim is:[%s]",
                          b_dims.size(), b_dims));
    PADDLE_ENFORCE_EQ(b_dims[0], 1,
                      platform::errors::InvalidArgument(
                          "The first dimension of Input(Bias) should be 1, but "
                          "received Bias first dim is:%d, Bias dim is:[%s]",
                          b_dims[0], b_dims));
    PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
                      platform::errors::InvalidArgument(
                          "The shape of Bias must be [1, frame_size * 3], but "
                          "received bias dim is:[%s], frame size is:%d",
                          b_dims, frame_size));
  }
  framework::DDim out_dims({x_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->ShareLoD("X", "Hidden");
  int xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    xx_width = wx_dims[1];
  } else {
    // Batch mode stores in XX whichever is narrower: the FC result (3D wide)
    // or the batched copy of X (M wide).
    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
    OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
                   "fusion_gru");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
                   "fusion_gru");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut",
                   "fusion_gru");
    ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
    ctx->SetOutputDim("BatchedOut", out_dims);
  }
  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

// Selects the kernel for fusion_gru.
//
// The data type always follows Input(X). In builds with PADDLE_WITH_MKLDNN,
// if platform::CanMKLDNNBeUsed(ctx) holds, the MKL-DNN library and layout are
// requested. Otherwise the plain library with kAnyLayout is used.
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  const auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  if (platform::CanMKLDNNBeUsed(ctx)) {
    return framework::OpKernelType(data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace(),
                                 framework::DataLayout::kAnyLayout,
                                 framework::LibraryType::kPlain);
}

void FusionGRUOpMaker::Make() {
T
tensor-tang 已提交
142 143
  AddInput("X",
           "(LoDTensor) the input is a LodTensor, which support "
T
tensor-tang 已提交
144
           "variable-time length input sequence. The underlying tensor in "
T
tensor-tang 已提交
145 146
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
T
tensor-tang 已提交
147 148 149 150 151
  AddInput("H0",
           "(Tensor, optional) The initial hidden state is an optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size, D is the hidden size.")
      .AsDispensable();
T
tensor-tang 已提交
152 153 154 155
  AddInput("WeightX",
           "(Tensor) The FC weight with shape (M x 3D),"
           "where M is the dim size of x, D is the hidden size. ");
  AddInput("WeightH",
T
tensor-tang 已提交
156 157 158 159 160
           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. "
           "This weight is not exactly D x 3D as: {W_update, W_reset, W_state}"
           "Acutally they are D x 2D and D x D two part weights."
           "{W_update, W_reset; W_state}"
           "{D x (D + D); D x D}");
T
tensor-tang 已提交
161
  AddInput("Bias",
T
tensor-tang 已提交
162 163 164
           "(Tensor, optional) (1 x 3D)."
           "Almost same as GRUOp."
           "Note: if have FC bias it should be added on this bias.")
T
tensor-tang 已提交
165
      .AsDispensable();
T
tensor-tang 已提交
166 167
  AddOutput("ReorderedH0", "(Tensor) (N x D), which N is the min-batch size.")
      .AsIntermediate();
T
tensor-tang 已提交
168
  AddOutput("XX",
T
tensor-tang 已提交
169
            "(LoDTensor) the result after X * WeightX (size is T x 3D)"
T
tensor-tang 已提交
170 171 172
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size, M is the dim size of x input.")
T
tensor-tang 已提交
173
      .AsIntermediate();
T
tensor-tang 已提交
174 175 176 177
  AddOutput("BatchedInput",
            "(LoDTensor) This is the batched result of input X"
            "or the batched result after fc, shape (T x 3D)")
      .AsIntermediate();
T
tensor-tang 已提交
178
  AddOutput("BatchedOut", "(LoDTensor) (T X D) save batched hidden.")
T
tensor-tang 已提交
179
      .AsIntermediate();
T
tensor-tang 已提交
180
  AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
T
tensor-tang 已提交
181 182 183 184 185 186 187 188 189 190
  AddAttr<std::string>("activation",
                       "(string, default tanh) "
                       "The activation type used for output candidate {h}_t.")
      .SetDefault("tanh");
  AddAttr<std::string>(
      "gate_activation",
      "(string, default sigmoid) "
      "The activation type used in update gate and reset gate.")
      .SetDefault("sigmoid");
  AddAttr<bool>("is_reverse",
翟飞跃 已提交
191
                "(bool, default: False) "
T
tensor-tang 已提交
192 193
                "whether to compute reversed GRU.")
      .SetDefault(false);
T
tensor-tang 已提交
194
  AddAttr<bool>("use_seq",
翟飞跃 已提交
195
                "(bool, default: True) "
T
tensor-tang 已提交
196 197
                "whether to use seq mode to compute GRU.")
      .SetDefault(true);
A
Adam 已提交
198 199 200 201
  AddAttr<bool>("origin_mode",
                "bool"
                "use origin mode in article https://arxiv.org/abs/1412.3555")
      .SetDefault(false);
A
Adam 已提交
202 203 204
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
T
tensor-tang 已提交
205 206 207 208 209 210 211
  AddComment(R"DOC(
The Fusion complete GRU Operator.
This operator fuse the fully-connected operator into GRU, 
more details can refer to GRU op.
)DOC");
}

T
tensor-tang 已提交
212
template <typename T>
T
tensor-tang 已提交
213 214
class FusionGRUKernel : public framework::OpKernel<T> {
 public:
T
tensor-tang 已提交
215
  void Compute(const framework::ExecutionContext& ctx) const override {
T
tensor-tang 已提交
216
    if (ctx.Attr<bool>("use_seq")) {
T
tensor-tang 已提交
217 218 219 220 221 222
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }

T
tensor-tang 已提交
223 224 225 226 227 228 229 230 231 232
#define INIT_BASE_DEFINES                  \
  auto* x = ctx.Input<LoDTensor>("X");     \
  auto* wh = ctx.Input<Tensor>("WeightH"); \
  auto* xx = ctx.Output<LoDTensor>("XX");  \
  auto x_lod = x->lod();                   \
  auto x_dims = x->dims();   /* T x M*/    \
  auto wh_dims = wh->dims(); /* D x 3D*/   \
  const int total_T = x_dims[0];           \
  const int D3 = wh_dims[1]

233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
#define INIT_OTHER_DEFINES                                                   \
  auto* h0 = ctx.Input<Tensor>("H0");                                        \
  auto* wx = ctx.Input<Tensor>("WeightX");                                   \
  auto* bias = ctx.Input<Tensor>("Bias");                                    \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                        \
  bool is_reverse = ctx.Attr<bool>("is_reverse");                            \
  const int M = x_dims[1];                                                   \
  const int D = wh_dims[0];                                                  \
  const int D2 = D * 2;                                                      \
  const jit::gru_attr_t attr(                                                \
      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("activation")));              \
  jit::gru_t one_step;                                                       \
  auto ComputeH1 =                                                           \
      jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At(  \
          attr);                                                             \
  auto ComputeHtPart1 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  auto ComputeHtPart2 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  const T* x_data = x->data<T>();                                            \
  const T* wx_data = wx->data<T>();                                          \
  const T* wh_data = wh->data<T>();                                          \
  auto place = ctx.GetPlace();                                               \
T
tensor-tang 已提交
259
  T* xx_data = xx->mutable_data<T>(place)
T
tensor-tang 已提交
260

T
tensor-tang 已提交
261 262
  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
T
tensor-tang 已提交
263 264
    INIT_BASE_DEFINES;
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
265
    const int N = x_lod[0].size() - 1;
T
tensor-tang 已提交
266
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
T
tensor-tang 已提交
267
    const T* wh_state_data = wh_data + D * D2;
T
tensor-tang 已提交
268
    T* hidden_out_data = hidden_out->mutable_data<T>(place);
T
tensor-tang 已提交
269
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
270 271 272 273 274

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::FCFunctor<DeviceContext, T> fc;
    fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
       bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291

    int xx_offset = D3;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 3;
      hidden_out_data = hidden_out_data + offset;
      xx_offset = -D3;
      gate_offset = -D;
    }
    auto move_step = [&]() {
      xx_data = xx_data + xx_offset;
      hidden_out_data = hidden_out_data + gate_offset;
    };
    for (int i = 0; i < N; ++i) {
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
T
tensor-tang 已提交
292
      const T* prev_hidden_data = nullptr;
T
tensor-tang 已提交
293 294 295 296
      int tstart = 0;
      if (h0_data) {
        prev_hidden_data = h0_data + bid * D;
      } else {
297 298
        one_step.gates = xx_data;
        one_step.ht = hidden_out_data;
299
        ComputeH1(&one_step, &attr);
T
tensor-tang 已提交
300 301 302 303 304 305 306 307 308
        prev_hidden_data = hidden_out_data;
        tstart = 1;
        move_step();
      }
      for (int step = tstart; step < seq_len; ++step) {
        // gemm prev * (Wu + Wr)
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
                  prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
                  D3);
309 310 311
        one_step.gates = xx_data;
        one_step.ht_1 = prev_hidden_data;
        one_step.ht = hidden_out_data;
312
        ComputeHtPart1(&one_step, &attr);
T
tensor-tang 已提交
313 314 315 316
        // gemm rt * Ws
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                  hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                  xx_data + D2, D3);
317
        ComputeHtPart2(&one_step, &attr);
T
tensor-tang 已提交
318 319 320 321 322 323 324 325
        // save prev
        prev_hidden_data = hidden_out_data;
        move_step();
      }
    }
  }

  void BatchCompute(const framework::ExecutionContext& ctx) const {
T
tensor-tang 已提交
326
    using DeviceContext = paddle::platform::CPUDeviceContext;
T
tensor-tang 已提交
327 328
    INIT_BASE_DEFINES;
    if (x_lod[0].size() == 2) {
329
      xx->Resize({total_T, D3});
T
tensor-tang 已提交
330 331 332
      SeqCompute(ctx);
      return;
    }
T
tensor-tang 已提交
333
    INIT_OTHER_DEFINES;
T
tensor-tang 已提交
334 335 336
    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
T
tensor-tang 已提交
337 338 339
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_out_data = batched_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
T
tensor-tang 已提交
340 341 342
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
343 344

    math::FCFunctor<DeviceContext, T> fc;
T
tensor-tang 已提交
345
    if (M > D3) {
346 347
      fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
         bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
348
      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
T
tensor-tang 已提交
349 350
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
T
tensor-tang 已提交
351
      batched_input->set_lod(xx->lod());
352 353
      fc(dev_ctx, total_T, D3, M, xx_data, wx_data, batched_input_data,
         bias ? bias->data<T>() : nullptr);
T
tensor-tang 已提交
354 355
    }

T
tensor-tang 已提交
356 357 358 359
    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});
T
tensor-tang 已提交
360

T
tensor-tang 已提交
361
    int tstart = 0;
T
tensor-tang 已提交
362
    T* prev_hidden_data = nullptr;
T
tensor-tang 已提交
363
    if (h0) {
T
tensor-tang 已提交
364
      // reorder h0
T
tensor-tang 已提交
365
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
T
tensor-tang 已提交
366 367 368 369 370 371 372
      const T* h0_data = h0->data<T>();
      prev_hidden_data = reordered_h0_data;
      size_t sz = sizeof(T) * D;
      for (int i = 0; i < max_bs; ++i) {
        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
        reordered_h0_data += D;
      }
T
tensor-tang 已提交
373
    } else {
T
tensor-tang 已提交
374 375 376 377 378
      // compute without h0
      T* cur_in_data = batched_input_data;
      T* cur_out_data = batched_out_data;
      // W: {W_update, W_reset; W_state}
      for (int i = 0; i < max_bs; ++i) {
379 380
        one_step.gates = cur_in_data;
        one_step.ht = cur_out_data;
381
        ComputeH1(&one_step, &attr);
T
tensor-tang 已提交
382 383 384 385 386 387
        // add offset
        cur_in_data += D3;
        cur_out_data += D;
      }
      tstart = 1;
      prev_hidden_data = batched_out_data;
T
tensor-tang 已提交
388
    }
T
tensor-tang 已提交
389 390 391 392 393 394 395 396 397 398 399 400 401 402
    // Then start from next
    const T* wh_state_data = wh_data + D * D2;
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
    batched_input_data = batched_input_data + tstart * max_bs * D3;
    batched_out_data = batched_out_data + tstart * max_bs * D;
    for (int step = tstart; step < max_seq_len; ++step) {
      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
      // gemm prev * (Wu + Wr)
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D2, D, static_cast<T>(1),
                prev_hidden_data, D, wh_data, D2, static_cast<T>(1),
                batched_input_data, D3);

      T* cur_batched_data = batched_input_data;
403
      T* cur_out_data = batched_out_data;
T
tensor-tang 已提交
404 405
      T* cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
406 407 408
        one_step.gates = cur_batched_data;
        one_step.ht_1 = cur_prev_hidden_data;
        one_step.ht = cur_out_data;
409
        ComputeHtPart1(&one_step, &attr);
410

T
tensor-tang 已提交
411 412
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
413
        cur_out_data += D;
T
tensor-tang 已提交
414 415
      }

T
tensor-tang 已提交
416
      cur_batched_data = batched_input_data;
417
      cur_out_data = batched_out_data;
T
tensor-tang 已提交
418
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1),
419
                cur_out_data, D, wh_state_data, D, static_cast<T>(1),
T
tensor-tang 已提交
420 421 422 423
                cur_batched_data + D2, D3);

      cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
424 425 426
        one_step.gates = cur_batched_data;
        one_step.ht_1 = cur_prev_hidden_data;
        one_step.ht = cur_out_data;
427
        ComputeHtPart2(&one_step, &attr);
T
tensor-tang 已提交
428 429 430
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
        cur_out_data += D;
T
tensor-tang 已提交
431
      }
T
tensor-tang 已提交
432 433 434
      prev_hidden_data = batched_out_data;
      batched_out_data = cur_out_data;
      batched_input_data = cur_batched_data;
T
tensor-tang 已提交
435
    }
T
tensor-tang 已提交
436

T
tensor-tang 已提交
437
    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
T
tensor-tang 已提交
438 439
    batched_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_out, hidden_out);
T
tensor-tang 已提交
440
  }
T
tensor-tang 已提交
441 442
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
T
tensor-tang 已提交
443 444 445 446 447 448
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
449 450
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);

T
tensor-tang 已提交
451 452
REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
                       ops::FusionGRUKernel<double>);