/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gru_op.h"
namespace paddle {
namespace platform {
// Forward declaration only: this translation unit names CUDADeviceContext
// purely as a template argument in the kernel registrations at the end of the
// file, so the complete type is not required at this point.
class CUDADeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

/**
 * Forward kernel of the `gru` operator for device type `DeviceContext` and
 * element type `T`.
 *
 * Compute() simply delegates to BatchCompute(), which
 *   1. reorders the "Input" LoDTensor into batch form with
 *      LoDTensor2BatchFunctor (honouring the "is_reverse" attribute),
 *   2. adds "Bias" row-wise to the batched gates when a bias is supplied,
 *   3. runs GRUUnitFunctor once per batch slice, feeding each slice's output
 *      in as the previous state of the next slice, and
 *   4. converts the batched hidden states back into "Hidden" with
 *      Batch2LoDTensorFunctor.
 */
template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
 public:
  /**
   * Runs the whole forward pass described on the class.
   *
   * Reads the attributes "is_test", "origin_mode", "is_reverse", "activation"
   * and "gate_activation"; the inputs "Input", "Weight" and the optional "H0"
   * and "Bias"; and writes the output "Hidden". Outside of test mode it also
   * writes the outputs "BatchGate", "BatchResetHiddenPrev" and "BatchHidden".
   *
   * @param context execution context supplying attributes, tensors, the
   *                target place and the device context.
   */
  void BatchCompute(const framework::ExecutionContext& context) const {
    using LodTensorPtr = LoDTensor*;

    bool is_test = context.Attr<bool>("is_test");
    bool origin_mode = context.Attr<bool>("origin_mode");
    auto* input = context.Input<LoDTensor>("Input");
    auto* h0 = context.Input<Tensor>("H0");
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* bias = context.Input<Tensor>("Bias");
    auto* hidden = context.Output<LoDTensor>("Hidden");
    hidden->mutable_data<T>(context.GetPlace());

    auto input_dims = input->dims();
    auto hidden_dims = hidden->dims();

    // In test mode the intermediate batch buffers are function-local
    // temporaries; otherwise they are real operator outputs.
    LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
    LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
    if (is_test) {
      batch_gate = &batch_gate_tmp;
      batch_gate->Resize(input_dims);

      batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp;
      batch_reset_hidden_prev->Resize(hidden_dims);

      batch_hidden = &batch_hidden_tmp;
      batch_hidden->Resize(hidden_dims);
    } else {
      batch_gate = context.Output<LoDTensor>("BatchGate");
      batch_hidden = context.Output<LoDTensor>("BatchHidden");
      batch_reset_hidden_prev =
          context.Output<LoDTensor>("BatchResetHiddenPrev");
    }
    batch_gate->mutable_data<T>(context.GetPlace());
    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
    batch_hidden->mutable_data<T>(context.GetPlace());

    bool is_reverse = context.Attr<bool>("is_reverse");
    phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);

    if (bias) {
      phi::funcs::RowwiseAdd<DeviceContext, T> add_bias;
      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
    }

    // The dimension is a wide integer while the GRU functor takes an int, so
    // make the narrowing explicit.
    int frame_size = static_cast<int>(hidden_dims[1]);
    phi::funcs::GRUMetaValue<T> gru_value;
    // Gate weights sit at the start of the weight buffer; the state weights
    // begin 2 * frame_size * frame_size elements further in.
    gru_value.gate_weight = const_cast<T*>(weight_data);
    gru_value.state_weight =
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
    Tensor ordered_h0;

    if (h0) {
      // Since the batch computing for GRU reorders the input sequences
      // according to their length. The initialized cell state also needs
      // to reorder.
      // The reorder index (level 2 of the batch LoD) is only needed on this
      // path, so it is copied here rather than unconditionally.
      framework::Vector<size_t> order(batch_gate->lod()[2]);
      ReorderInitState<DeviceContext, T>(
          dev_ctx, *h0, order, &ordered_h0, true);
      gru_value.prev_out_value = ordered_h0.data<T>();
    } else {
      gru_value.prev_out_value = nullptr;
    }

    auto batch_starts = batch_gate->lod()[0];
    // Level 0 of the batch LoD holds the slice boundaries, i.e. one more entry
    // than there are slices. Guard the subtraction so that an empty boundary
    // list yields zero iterations instead of wrapping around in size_t.
    const size_t num_boundaries = batch_starts.size();
    size_t num_batch = num_boundaries > 0 ? num_boundaries - 1 : 0;
    auto active_node = phi::funcs::detail::GetActivationType(
        context.Attr<std::string>("activation"));
    auto active_gate = phi::funcs::detail::GetActivationType(
        context.Attr<std::string>("gate_activation"));
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
      int cur_batch_size = bend - bstart;

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
      gru_value.output_value = hidden_t.data<T>();
      gru_value.gate_value = gate_t.data<T>();
      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
      phi::funcs::GRUUnitFunctor<DeviceContext, T>::compute(dev_ctx,
                                                            gru_value,
                                                            frame_size,
                                                            cur_batch_size,
                                                            active_node,
                                                            active_gate,
                                                            origin_mode);
      // This slice's output becomes the previous state of the next slice.
      gru_value.prev_out_value = gru_value.output_value;
    }

    phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    // to_seq is given batch_hidden only, so copy the batch LoD onto it first.
    batch_hidden->set_lod(batch_gate->lod());
    to_seq(dev_ctx, *batch_hidden, hidden);
  }

  /** OpKernel entry point; forwards to BatchCompute(). */
  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

}  // namespace operators
}  // namespace paddle

// Register the CUDA kernels of the `gru` operator and of `gru_grad`, each for
// float and double element types.
// NOTE(review): GRUGradKernel is not defined in this file; presumably it is
// provided by gru_op.h — confirm.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    gru,
    ops::GRUKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GRUKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    gru_grad,
    ops::GRUGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GRUGradKernel<paddle::platform::CUDADeviceContext, double>);