/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gru_op.h"
namespace paddle {
namespace operators {

// Forward kernel of the `gru` operator for `DeviceContext` and element type T.
//
// Read from the execution context:
//   inputs     - "Input" (LoDTensor), "Weight", and the optional "H0" and
//                "Bias" (both are checked for null before use);
//   outputs    - "Hidden", plus "BatchGate", "BatchHidden" and
//                "BatchResetHiddenPrev" when `is_test` is false;
//   attributes - "is_test", "origin_mode", "is_reverse", "activation",
//                "gate_activation".
template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
 public:
  // Runs the GRU forward pass: converts "Input" into batch layout, runs
  // GRUUnitFunctor on every batch slice in order, then converts the batched
  // hidden states back into sequence layout in "Hidden".
  void BatchCompute(const framework::ExecutionContext& context) const {
    using LodTensorPtr = LoDTensor*;

    bool is_test = context.Attr<bool>("is_test");
    bool origin_mode = context.Attr<bool>("origin_mode");
    auto* input = context.Input<LoDTensor>("Input");
    auto* h0 = context.Input<Tensor>("H0");
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* bias = context.Input<Tensor>("Bias");
    auto* hidden = context.Output<LoDTensor>("Hidden");
    hidden->mutable_data<T>(context.GetPlace());

    auto input_dims = input->dims();
    auto hidden_dims = hidden->dims();

    // In test mode the intermediate batch tensors are function-local
    // temporaries; otherwise they are the op's declared outputs.
    LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
    LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
    if (is_test) {
      batch_gate = &batch_gate_tmp;
      batch_gate->Resize(input_dims);

      batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
      batch_reset_hidden_prev->Resize(hidden_dims);

      batch_hidden = &batch_hidden_tmp;
      batch_hidden->Resize(hidden_dims);
    } else {
      batch_gate = context.Output<LoDTensor>("BatchGate");
      batch_hidden = context.Output<LoDTensor>("BatchHidden");
      batch_reset_hidden_prev =
          context.Output<LoDTensor>("BatchResetHiddenPrev");
    }
    batch_gate->mutable_data<T>(context.GetPlace());
    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
    batch_hidden->mutable_data<T>(context.GetPlace());

    bool is_reverse = context.Attr<bool>("is_reverse");
    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);

    if (bias) {
      // The bias is added in place: batch_gate is both input and output.
      phi::funcs::RowwiseAdd<DeviceContext, T> add_bias;
      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
    }

    // Explicit narrowing: the GRU functor takes the frame size as an int.
    int frame_size = static_cast<int>(hidden_dims[1]);
    phi::funcs::GRUMetaValue<T> gru_value;
    // Both weight pointers alias the single "Weight" buffer; the state weight
    // begins 2 * frame_size * frame_size elements after the gate weight.
    gru_value.gate_weight = const_cast<T*>(weight_data);
    gru_value.state_weight =
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
    Tensor ordered_h0;

    if (h0) {
      // Since the batch computing for GRU reorders the input sequences
      // according to their length. The initialized cell state also needs
      // to reorder.
      // The reorder index is only needed on this path, so it is copied here
      // rather than unconditionally.
      framework::Vector<size_t> order(batch_gate->lod()[2]);
      ReorderInitState<DeviceContext, T>(
          dev_ctx, *h0, order, &ordered_h0, true);
      gru_value.prev_out_value = ordered_h0.data<T>();
    } else {
      gru_value.prev_out_value = nullptr;
    }

    auto batch_starts = batch_gate->lod()[0];
    // Guard the unsigned subtraction: an empty level would otherwise wrap
    // around to SIZE_MAX and drive the loop far out of bounds.
    const size_t num_boundaries = batch_starts.size();
    const size_t num_batch = num_boundaries > 0 ? num_boundaries - 1 : 0;

    auto active_node = phi::funcs::detail::GetActivationType(
        context.Attr<std::string>("activation"));
    auto active_gate = phi::funcs::detail::GetActivationType(
        context.Attr<std::string>("gate_activation"));
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
      int cur_batch_size = bend - bstart;

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
      gru_value.output_value = hidden_t.data<T>();
      gru_value.gate_value = gate_t.data<T>();
      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
      phi::funcs::GRUUnitFunctor<DeviceContext, T>::compute(dev_ctx,
                                                            gru_value,
                                                            frame_size,
                                                            cur_batch_size,
                                                            active_node,
                                                            active_gate,
                                                            origin_mode);
      // The output of this slice is the previous state for the next slice.
      gru_value.prev_out_value = gru_value.output_value;
    }

    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    // batch_hidden takes batch_gate's LoD before being converted back.
    batch_hidden->set_lod(batch_gate->lod());
    to_seq(dev_ctx, *batch_hidden, hidden);
  }

  // OpKernel entry point; forwards to BatchCompute.
  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the float and double CUDA kernels for the forward `gru` op.
REGISTER_OP_CUDA_KERNEL(
    gru,
    ops::GRUKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GRUKernel<paddle::platform::CUDADeviceContext, double>);
// Register the float and double CUDA kernels for `gru_grad`. GRUGradKernel is
// not defined in this file; presumably it comes from gru_op.h.
REGISTER_OP_CUDA_KERNEL(
    gru_grad,
    ops::GRUGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GRUGradKernel<paddle::platform::CUDADeviceContext, double>);