/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gru_op.h"
// Forward declarations of platform types, so this translation unit does not
// need their full definitions. CUDADeviceContext is the DeviceContext used by
// the kernel registrations at the end of this file; CUDAPlace is presumably
// referenced by the REGISTER_OP_CUDA_KERNEL expansion (not visible here).
namespace paddle {
namespace platform {
struct CUDAPlace;
class CUDADeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

// CUDA forward kernel for the `gru` operator.
//
// The input sequences are rearranged into batches by LoDTensor2BatchFunctor.
// The GRU cell is then evaluated batch by batch with math::GRUUnitFunctor.
// Each batch's "previous hidden state" is the output of the preceding batch,
// so the batches form the recurrent steps. Finally the per-batch hidden
// states are scattered back into sequence order by Batch2LoDTensorFunctor.
template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
 public:
  // Runs the forward pass described above.
  //
  // Inputs : "Input" (LoDTensor), optional "H0", "Weight", optional "Bias".
  // Outputs: "Hidden". When the "is_test" attribute is false, the
  //          intermediate tensors are also written to the op outputs
  //          "BatchGate", "BatchResetHiddenPrev" and "BatchHidden"
  //          (presumably consumed by the gradient kernel — confirm).
  //          When "is_test" is true they live in local temporaries.
  // Attrs  : "is_test", "origin_mode", "is_reverse", "activation",
  //          "gate_activation".
  void BatchCompute(const framework::ExecutionContext& context) const {
    using LodTensorPtr = LoDTensor*;

    const bool is_test = context.Attr<bool>("is_test");
    const bool origin_mode = context.Attr<bool>("origin_mode");
    auto* input = context.Input<LoDTensor>("Input");
    auto* h0 = context.Input<Tensor>("H0");
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* bias = context.Input<Tensor>("Bias");
    auto* hidden = context.Output<LoDTensor>("Hidden");
    hidden->mutable_data<T>(context.GetPlace());

    auto input_dims = input->dims();
    auto hidden_dims = hidden->dims();

    // Select where the intermediate batch tensors are stored: local
    // temporaries in test mode, op outputs otherwise.
    LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
    LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
    if (is_test) {
      batch_gate = &batch_gate_tmp;
      batch_gate->Resize(input_dims);

      batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
      batch_reset_hidden_prev->Resize(hidden_dims);

      batch_hidden = &batch_hidden_tmp;
      batch_hidden->Resize(hidden_dims);
    } else {
      batch_gate = context.Output<LoDTensor>("BatchGate");
      batch_hidden = context.Output<LoDTensor>("BatchHidden");
      batch_reset_hidden_prev =
          context.Output<LoDTensor>("BatchResetHiddenPrev");
    }
    batch_gate->mutable_data<T>(context.GetPlace());
    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
    batch_hidden->mutable_data<T>(context.GetPlace());

    const bool is_reverse = context.Attr<bool>("is_reverse");
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    // Copy Input into batch_gate in batch order. The LoD attached to
    // batch_gate describes the batch layout used below.
    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);

    if (bias) {
      math::RowwiseAdd<DeviceContext, T> add_bias;
      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
    }

    const int frame_size = static_cast<int>(hidden_dims[1]);
    // Weight is consumed as two consecutive regions. The first
    // 2 * frame_size * frame_size elements are the gate weights; the
    // remainder are the state weights.
    math::GRUMetaValue<T> gru_value;
    gru_value.gate_weight = const_cast<T*>(weight_data);
    gru_value.state_weight =
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
    Tensor ordered_h0;

    // LoD level 2 of the batched tensor is used as the sequence reorder
    // index for H0.
    framework::Vector<size_t> order(batch_gate->lod()[2]);

    if (h0) {
      // Batch computation reorders the input sequences by length, so the
      // initial hidden state (H0) must be reordered the same way.
      ReorderInitState<DeviceContext, T>(dev_ctx, *h0, order, &ordered_h0,
                                         true);
      gru_value.prev_out_value = ordered_h0.data<T>();
    } else {
      gru_value.prev_out_value = nullptr;
    }

    // Row offsets of each batch inside the batched tensors. Bound by
    // reference: batch_gate's LoD is not modified inside the loop below.
    const auto& batch_starts = batch_gate->lod()[0];
    const size_t num_batch = batch_starts.size() - 1;
    auto active_node = math::detail::GetActivationType(
        context.Attr<std::string>("activation"));
    auto active_gate = math::detail::GetActivationType(
        context.Attr<std::string>("gate_activation"));
    for (size_t n = 0; n < num_batch; n++) {
      const int bstart = static_cast<int>(batch_starts[n]);
      const int bend = static_cast<int>(batch_starts[n + 1]);
      const int cur_batch_size = bend - bstart;

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
      gru_value.output_value = hidden_t.data<T>();
      gru_value.gate_value = gate_t.data<T>();
      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
      math::GRUUnitFunctor<DeviceContext, T>::compute(
          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
          active_gate, origin_mode);
      // The hidden state just produced is the previous state of the next
      // batch. Slices share storage with batch_hidden, so the pointer stays
      // valid after hidden_t goes out of scope.
      gru_value.prev_out_value = gru_value.output_value;
    }

    // Scatter the batched hidden states back into the original sequence
    // layout of Hidden.
    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batch_hidden->set_lod(batch_gate->lod());
    to_seq(dev_ctx, *batch_hidden, hidden);
  }

  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward GRU kernel on CUDA devices, for float and double.
REGISTER_OP_CUDA_KERNEL(
    gru, ops::GRUKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GRUKernel<paddle::platform::CUDADeviceContext, double>);

// Gradient GRU kernel on CUDA devices, for float and double. GRUGradKernel
// is not defined in this file (presumably provided by gru_op.h).
REGISTER_OP_CUDA_KERNEL(
    gru_grad, ops::GRUGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GRUGradKernel<paddle::platform::CUDADeviceContext, double>);