// fusion_gru_op.cc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include <cstring>  // for memcpy
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
T
tensor-tang 已提交
25 26 27 28 29

namespace paddle {
namespace operators {

void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
30 31 32 33 34
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru");
  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru");
T
tensor-tang 已提交
35
  auto x_dims = ctx->GetInputDim("X");
36 37 38 39 40 41 42 43 44
  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
                        ? framework::flatten_to_2d(x_dims, 1)
                        : x_dims;
  PADDLE_ENFORCE_EQ(
      x_mat_dims.size(), 2,
      platform::errors::InvalidArgument("The size of input X dims should be 2, "
                                        "or 3 with second dimension equal to "
                                        "1, but now Input X dim is:[%s] ",
                                        x_dims));
T
tensor-tang 已提交
45 46 47

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
48 49 50 51
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightX) should be 2, but received "
                        "WeightX dim size is:%d, WeightX dim is:[%s] ",
                        wx_dims.size(), wx_dims));
52 53 54 55 56 57 58 59
  PADDLE_ENFORCE_EQ(
      wx_dims[0], x_mat_dims[1],
      platform::errors::InvalidArgument(
          "The first dimension of flattened WeightX"
          "should equal to last dimension of flattened input X, but "
          "received fattened WeightX dimension is:%d, flattened X dimension "
          "is:%d",
          wx_dims[0], x_mat_dims[1]));
T
tensor-tang 已提交
60 61 62

  int frame_size = wx_dims[1] / 3;
  auto wh_dims = ctx->GetInputDim("WeightH");
63

T
tensor-tang 已提交
64
  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
65 66 67 68
                    platform::errors::InvalidArgument(
                        "The rank of Input(WeightH) should be 2, but received "
                        "WeightH dim size is:%d, WeightH dim is:[%s]",
                        wh_dims.size(), wh_dims));
T
tensor-tang 已提交
69
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
70 71 72 73 74 75
                    platform::errors::InvalidArgument(
                        "The first dimension of WeightH "
                        "should equal to frame_size, but received WeightH's "
                        "first dimension is: "
                        "%d, frame size is:%d",
                        wh_dims[0], frame_size));
T
tensor-tang 已提交
76
  PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
77 78 79 80 81
                    platform::errors::InvalidArgument(
                        "The second dimension of Input(WeightH) "
                        "should equal to 3 * frame_size, but received WeightH "
                        "is:%d, frame size is:%d",
                        wh_dims[1], frame_size));
T
tensor-tang 已提交
82

83
  if (ctx->HasInput("H0")) {
T
tensor-tang 已提交
84 85
    auto h0_dims = ctx->GetInputDim("H0");
    PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
86 87 88 89
                      platform::errors::InvalidArgument(
                          "The width of H0 must be equal to frame_size, but "
                          "receiced the width of H0 is:%d, frame size is:%d",
                          h0_dims[1], frame_size));
T
tensor-tang 已提交
90
  }
91
  if (ctx->HasInput("Bias")) {
T
tensor-tang 已提交
92
    auto b_dims = ctx->GetInputDim("Bias");
93 94 95 96 97
    PADDLE_ENFORCE_EQ(b_dims.size(), 2,
                      platform::errors::InvalidArgument(
                          "The rank of Input(Bias) should be 2, but received "
                          "Bias rank is:%d, Bias dim is:[%s]",
                          b_dims.size(), b_dims));
T
tensor-tang 已提交
98
    PADDLE_ENFORCE_EQ(b_dims[0], 1,
99 100 101 102
                      platform::errors::InvalidArgument(
                          "The first dimension of Input(Bias) should be 1, but "
                          "received Bias first dim is:%d, Bias dim is:[%s]",
                          b_dims[0], b_dims));
T
tensor-tang 已提交
103
    PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
104 105 106 107
                      platform::errors::InvalidArgument(
                          "The shape of Bias must be [1, frame_size * 3], but "
                          "received bias dim is:[%s], frame size is:%d",
                          b_dims, frame_size));
T
tensor-tang 已提交
108
  }
109
  framework::DDim out_dims({x_mat_dims[0], frame_size});
T
tensor-tang 已提交
110 111
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->ShareLoD("X", "Hidden");
T
tensor-tang 已提交
112
  int xx_width;
T
tensor-tang 已提交
113
  if (ctx->Attrs().Get<bool>("use_seq")) {
T
tensor-tang 已提交
114 115
    xx_width = wx_dims[1];
  } else {
116
    xx_width = x_mat_dims[1] > wx_dims[1] ? wx_dims[1] : x_mat_dims[1];
117 118 119 120 121 122
    OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
                   "fusion_gru");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
                   "fusion_gru");
    OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut",
                   "fusion_gru");
123
    ctx->SetOutputDim("BatchedInput", {x_mat_dims[0], wx_dims[1]});
T
tensor-tang 已提交
124
    ctx->SetOutputDim("BatchedOut", out_dims);
T
tensor-tang 已提交
125
  }
126
  ctx->SetOutputDim("XX", {x_mat_dims[0], xx_width});
T
tensor-tang 已提交
127
  ctx->ShareLoD("X", "XX");
T
tensor-tang 已提交
128 129 130 131
}

// Chooses the kernel variant for this op: the plain CPU kernel by default,
// or the MKL-DNN kernel/layout when the build and runtime context allow it.
// The kernel data type is taken from input "X".
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  auto kernel_lib = framework::LibraryType::kPlain;
  auto data_layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
  if (platform::CanMKLDNNBeUsed(ctx)) {
    kernel_lib = framework::LibraryType::kMKLDNN;
    data_layout = framework::DataLayout::kMKLDNN;
  }
#endif
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
  return framework::OpKernelType(data_type, ctx.GetPlace(), data_layout,
                                 kernel_lib);
}

void FusionGRUOpMaker::Make() {
T
tensor-tang 已提交
146 147
  AddInput("X",
           "(LoDTensor) the input is a LodTensor, which support "
T
tensor-tang 已提交
148
           "variable-time length input sequence. The underlying tensor in "
T
tensor-tang 已提交
149 150
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
T
tensor-tang 已提交
151 152 153 154 155
  AddInput("H0",
           "(Tensor, optional) The initial hidden state is an optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size, D is the hidden size.")
      .AsDispensable();
T
tensor-tang 已提交
156 157 158 159
  AddInput("WeightX",
           "(Tensor) The FC weight with shape (M x 3D),"
           "where M is the dim size of x, D is the hidden size. ");
  AddInput("WeightH",
T
tensor-tang 已提交
160 161 162 163 164
           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. "
           "This weight is not exactly D x 3D as: {W_update, W_reset, W_state}"
           "Acutally they are D x 2D and D x D two part weights."
           "{W_update, W_reset; W_state}"
           "{D x (D + D); D x D}");
T
tensor-tang 已提交
165
  AddInput("Bias",
T
tensor-tang 已提交
166 167 168
           "(Tensor, optional) (1 x 3D)."
           "Almost same as GRUOp."
           "Note: if have FC bias it should be added on this bias.")
T
tensor-tang 已提交
169
      .AsDispensable();
T
tensor-tang 已提交
170 171
  AddOutput("ReorderedH0", "(Tensor) (N x D), which N is the min-batch size.")
      .AsIntermediate();
T
tensor-tang 已提交
172
  AddOutput("XX",
T
tensor-tang 已提交
173
            "(LoDTensor) the result after X * WeightX (size is T x 3D)"
T
tensor-tang 已提交
174 175 176
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size, M is the dim size of x input.")
T
tensor-tang 已提交
177
      .AsIntermediate();
T
tensor-tang 已提交
178 179 180 181
  AddOutput("BatchedInput",
            "(LoDTensor) This is the batched result of input X"
            "or the batched result after fc, shape (T x 3D)")
      .AsIntermediate();
T
tensor-tang 已提交
182
  AddOutput("BatchedOut", "(LoDTensor) (T X D) save batched hidden.")
T
tensor-tang 已提交
183
      .AsIntermediate();
T
tensor-tang 已提交
184
  AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
T
tensor-tang 已提交
185 186 187 188 189 190 191 192 193 194
  AddAttr<std::string>("activation",
                       "(string, default tanh) "
                       "The activation type used for output candidate {h}_t.")
      .SetDefault("tanh");
  AddAttr<std::string>(
      "gate_activation",
      "(string, default sigmoid) "
      "The activation type used in update gate and reset gate.")
      .SetDefault("sigmoid");
  AddAttr<bool>("is_reverse",
翟飞跃 已提交
195
                "(bool, default: False) "
T
tensor-tang 已提交
196 197
                "whether to compute reversed GRU.")
      .SetDefault(false);
T
tensor-tang 已提交
198
  AddAttr<bool>("use_seq",
翟飞跃 已提交
199
                "(bool, default: True) "
T
tensor-tang 已提交
200 201
                "whether to use seq mode to compute GRU.")
      .SetDefault(true);
A
Adam 已提交
202 203 204 205
  AddAttr<bool>("origin_mode",
                "bool"
                "use origin mode in article https://arxiv.org/abs/1412.3555")
      .SetDefault(false);
A
Adam 已提交
206 207 208
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
T
tensor-tang 已提交
209 210 211 212 213 214 215
  AddComment(R"DOC(
The Fusion complete GRU Operator.
This operator fuse the fully-connected operator into GRU, 
more details can refer to GRU op.
)DOC");
}

// CPU kernel for the fused GRU. Dispatches to one of two execution paths:
//  - SeqCompute: processes each sequence step-by-step in LoD order
//    (used when attr "use_seq" is true, or when there is only one sequence).
//  - BatchCompute: reorders sequences into time-step batches so every GEMM
//    covers all sequences active at that step.
// Both paths share their local-variable setup through the INIT_* macros
// below. T is the element type (float/double).
template <typename T>
class FusionGRUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    if (ctx.Attr<bool>("use_seq")) {
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }

// Common locals shared by SeqCompute and BatchCompute. Kept as macros (not
// helpers) so both methods get the same set of named locals without passing
// a dozen parameters around. D = hidden size, D3 = 3*D (the three gates),
// total_T = total time steps across all sequences.
#define INIT_BASE_DEFINES                                     \
  auto* x = ctx.Input<LoDTensor>("X");                        \
  auto* wh = ctx.Input<Tensor>("WeightH");                    \
  auto* xx = ctx.Output<LoDTensor>("XX");                     \
  auto x_lod = x->lod();                                      \
  auto x_dims = x->dims(); /* T x M*/                         \
  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)    \
                        ? framework::flatten_to_2d(x_dims, 1) \
                        : x_dims;                             \
  auto wh_dims = wh->dims(); /* D x 3D*/                      \
  const int total_T = x_mat_dims[0];                          \
  const int D3 = wh_dims[1]

// Remaining shared locals: raw data pointers plus the cached JIT kernels.
// ComputeH1 evaluates the first step (no previous hidden state);
// ComputeHtPart1/Part2 split one GRU step around the candidate-state GEMM.
#define INIT_OTHER_DEFINES                                                   \
  auto* h0 = ctx.Input<Tensor>("H0");                                        \
  auto* wx = ctx.Input<Tensor>("WeightX");                                   \
  auto* bias = ctx.Input<Tensor>("Bias");                                    \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                        \
  bool is_reverse = ctx.Attr<bool>("is_reverse");                            \
  const int M = x_mat_dims[1];                                               \
  const int D = wh_dims[0];                                                  \
  const int D2 = D * 2;                                                      \
  const jit::gru_attr_t attr(                                                \
      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("activation")));              \
  jit::gru_t one_step;                                                       \
  auto ComputeH1 =                                                           \
      jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At(  \
          attr);                                                             \
  auto ComputeHtPart1 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  auto ComputeHtPart2 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  const T* x_data = x->data<T>();                                            \
  const T* wx_data = wx->data<T>();                                          \
  const T* wh_data = wh->data<T>();                                          \
  auto place = ctx.GetPlace();                                               \
  T* xx_data = xx->mutable_data<T>(place)

  // Sequence-order path: one FC over the whole input, then a scalar
  // (batch-of-1) GRU recurrence per sequence. The gate buffer xx is updated
  // in place by the GEMMs, so statement order here is load-bearing.
  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    INIT_BASE_DEFINES;
    INIT_OTHER_DEFINES;
    // N: number of sequences in the LoD batch.
    const int N = x_lod[0].size() - 1;
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    // W_state lives after the (D x 2D) update/reset block of WeightH.
    const T* wh_state_data = wh_data + D * D2;
    T* hidden_out_data = hidden_out->mutable_data<T>(place);
    auto blas = math::GetBlas<DeviceContext, T>(ctx);

    // xx = X * WeightX (+ bias): all gate pre-activations in one FC.
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::FCFunctor<DeviceContext, T> fc;
    fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
       bias ? bias->data<T>() : nullptr);

    // For reversed GRU, start at the last time step and walk backwards by
    // negating the per-step pointer offsets.
    int xx_offset = D3;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 3;
      hidden_out_data = hidden_out_data + offset;
      xx_offset = -D3;
      gate_offset = -D;
    }
    auto move_step = [&]() {
      xx_data = xx_data + xx_offset;
      hidden_out_data = hidden_out_data + gate_offset;
    };
    for (int i = 0; i < N; ++i) {
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
      const T* prev_hidden_data = nullptr;
      int tstart = 0;
      if (h0_data) {
        prev_hidden_data = h0_data + bid * D;
      } else {
        // No initial state: compute step 0 with the dedicated H1 kernel.
        one_step.gates = xx_data;
        one_step.ht = hidden_out_data;
        ComputeH1(&one_step, &attr);
        prev_hidden_data = hidden_out_data;
        tstart = 1;
        move_step();
      }
      for (int step = tstart; step < seq_len; ++step) {
        // gemm prev * (Wu + Wr)
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
                  prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
                  D3);
        one_step.gates = xx_data;
        one_step.ht_1 = prev_hidden_data;
        one_step.ht = hidden_out_data;
        ComputeHtPart1(&one_step, &attr);
        // gemm rt * Ws
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                  hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                  xx_data + D2, D3);
        ComputeHtPart2(&one_step, &attr);
        // save prev
        prev_hidden_data = hidden_out_data;
        move_step();
      }
    }
  }

  // Batch-order path: sequences are transposed into per-time-step batches
  // (LoDTensor2BatchFunctor) so each recurrence step runs one GEMM over all
  // sequences still active at that step, then the result is transposed back.
  void BatchCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    INIT_BASE_DEFINES;
    // A lod of size 2 means a single sequence — batching buys nothing,
    // so fall back to the sequential path.
    if (x_lod[0].size() == 2) {
      xx->Resize({total_T, D3});
      SeqCompute(ctx);
      return;
    }
    INIT_OTHER_DEFINES;
    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_out_data = batched_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;

    // Order FC vs. batching by data width: run the FC first when it shrinks
    // the data (M > 3D), otherwise batch the raw input first.
    math::FCFunctor<DeviceContext, T> fc;
    if (M > D3) {
      fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
         bias ? bias->data<T>() : nullptr);
      to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
      batched_input->set_lod(xx->lod());
      fc(dev_ctx, total_T, D3, M, xx_data, wx_data, batched_input_data,
         bias ? bias->data<T>() : nullptr);
    }

    // batched_lod[2] maps batch rows back to original sequence order;
    // max_bs is the number of sequences (batch size at step 0).
    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});

    int tstart = 0;
    T* prev_hidden_data = nullptr;
    if (h0) {
      // reorder h0
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
      const T* h0_data = h0->data<T>();
      prev_hidden_data = reordered_h0_data;
      size_t sz = sizeof(T) * D;
      for (int i = 0; i < max_bs; ++i) {
        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
        reordered_h0_data += D;
      }
    } else {
      // compute without h0
      T* cur_in_data = batched_input_data;
      T* cur_out_data = batched_out_data;
      // W: {W_update, W_reset; W_state}
      for (int i = 0; i < max_bs; ++i) {
        one_step.gates = cur_in_data;
        one_step.ht = cur_out_data;
        ComputeH1(&one_step, &attr);
        // add offset
        cur_in_data += D3;
        cur_out_data += D;
      }
      tstart = 1;
      prev_hidden_data = batched_out_data;
    }
    // Then start from next
    const T* wh_state_data = wh_data + D * D2;
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
    batched_input_data = batched_input_data + tstart * max_bs * D3;
    batched_out_data = batched_out_data + tstart * max_bs * D;
    for (int step = tstart; step < max_seq_len; ++step) {
      // cur_bs shrinks over steps as shorter sequences finish.
      const int cur_bs = batch_starts[step + 1] - batch_starts[step];
      // gemm prev * (Wu + Wr)
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D2, D, static_cast<T>(1),
                prev_hidden_data, D, wh_data, D2, static_cast<T>(1),
                batched_input_data, D3);

      // Update/reset gates for every active sequence at this step.
      T* cur_batched_data = batched_input_data;
      T* cur_out_data = batched_out_data;
      T* cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
        one_step.gates = cur_batched_data;
        one_step.ht_1 = cur_prev_hidden_data;
        one_step.ht = cur_out_data;
        ComputeHtPart1(&one_step, &attr);

        cur_batched_data += D3;
        cur_prev_hidden_data += D;
        cur_out_data += D;
      }

      // Candidate state: one GEMM for the whole step, then finish each row.
      cur_batched_data = batched_input_data;
      cur_out_data = batched_out_data;
      blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1),
                cur_out_data, D, wh_state_data, D, static_cast<T>(1),
                cur_batched_data + D2, D3);

      cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
        one_step.gates = cur_batched_data;
        one_step.ht_1 = cur_prev_hidden_data;
        one_step.ht = cur_out_data;
        ComputeHtPart2(&one_step, &attr);
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
        cur_out_data += D;
      }
      // Advance: this step's outputs become the next step's previous hidden.
      prev_hidden_data = batched_out_data;
      batched_out_data = cur_out_data;
      batched_input_data = cur_batched_data;
    }

    // Transpose the batched hidden states back into LoD sequence order.
    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batched_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_out, hidden_out);
  }
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
};

}  // namespace operators
}  // namespace paddle

// Register the op definition and its CPU kernels (float and double).
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);

REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
                       ops::FusionGRUKernel<double>);