gru_op.h 10.2 KB
Newer Older
G
guosheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

G
guosheng 已提交
27 28 29
// Short aliases for the framework tensor types used by the kernels below.
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

Q
QI JUN 已提交
30 31
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
G
guosheng 已提交
32 33
                             const framework::Tensor& src, const size_t* index,
                             framework::Tensor* dst, bool indexed_src) {
Q
QI JUN 已提交
34
  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
G
guosheng 已提交
35 36 37 38
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  row_shuffle(ctx, src, index, *dst, indexed_src);
}

Q
QI JUN 已提交
39
template <typename DeviceContext, typename T>
G
guosheng 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
class GRUKernel : public framework::OpKernel<T> {
 public:
  void BatchCompute(const framework::ExecutionContext& context) const {
    auto* input = context.Input<LoDTensor>("Input");
    auto* h0 = context.Input<Tensor>("H0");
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* bias = context.Input<Tensor>("Bias");
    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(context.GetPlace());
    auto* batch_reset_hidden_prev =
        context.Output<LoDTensor>("BatchResetHiddenPrev");
    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
    batch_hidden->mutable_data<T>(context.GetPlace());
    auto* hidden = context.Output<LoDTensor>("Hidden");
    hidden->mutable_data<T>(context.GetPlace());

    context.ShareLoD("Input", "Hidden");

    auto hidden_dims = hidden->dims();

    bool is_reverse = context.Attr<bool>("is_reverse");
Q
QI JUN 已提交
63 64
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& dev_ctx = context.template device_context<DeviceContext>();
65
    to_batch(dev_ctx, *input, *batch_gate, true, is_reverse);
G
guosheng 已提交
66 67

    if (bias) {
Q
QI JUN 已提交
68
      math::RowwiseAdd<DeviceContext, T> add_bias;
69
      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
G
guosheng 已提交
70 71
    }

72
    int frame_size = hidden_dims[1];
G
guosheng 已提交
73
    math::hl_gru_value<T> gru_value;
G
guosheng 已提交
74 75
    gru_value.gate_weight = const_cast<T*>(weight_data);
    gru_value.state_weight =
G
guosheng 已提交
76
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
G
guosheng 已提交
77 78 79 80 81 82
    Tensor ordered_h0;
    const size_t* order = batch_gate->lod()[2].data();
    if (h0) {
      // Since the batch computing for GRU reorders the input sequences
      // according to their length. The initialized cell state also needs
      // to reorder.
Q
QI JUN 已提交
83 84 85
      ReorderInitState<DeviceContext, T>(
          context.template device_context<DeviceContext>(), *h0, order,
          &ordered_h0, true);
G
guosheng 已提交
86
      gru_value.prev_out_value = ordered_h0.data<T>();
G
guosheng 已提交
87
    } else {
G
guosheng 已提交
88
      gru_value.prev_out_value = nullptr;
G
guosheng 已提交
89
    }
G
guosheng 已提交
90 91 92 93 94 95 96 97 98 99
    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
      int cur_batch_size = bend - bstart;

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
G
guosheng 已提交
100 101 102
      gru_value.output_value = hidden_t.data<T>();
      gru_value.gate_value = gate_t.data<T>();
      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
Q
QI JUN 已提交
103
      math::GRUUnitFunctor<DeviceContext, T>::compute(
104
          dev_ctx, gru_value, frame_size, cur_batch_size,
G
guosheng 已提交
105 106
          math::ActiveType(context.Attr<std::string>("activation")),
          math::ActiveType(context.Attr<std::string>("gate_activation")));
G
guosheng 已提交
107
      gru_value.prev_out_value = gru_value.output_value;
G
guosheng 已提交
108 109
    }

Q
QI JUN 已提交
110
    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
G
guosheng 已提交
111
    batch_hidden->set_lod(batch_gate->lod());
112
    to_seq(dev_ctx, *batch_hidden, *hidden);
G
guosheng 已提交
113 114 115 116 117 118 119
  }

  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

Q
QI JUN 已提交
120
template <typename DeviceContext, typename T>
G
guosheng 已提交
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
class GRUGradKernel : public framework::OpKernel<T> {
 public:
  void BatchCompute(const framework::ExecutionContext& context) const {
    auto* h0 = context.Input<Tensor>("H0");
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* batch_gate = context.Input<LoDTensor>("BatchGate");
    auto* batch_reset_hidden_prev =
        context.Input<LoDTensor>("BatchResetHiddenPrev");
    auto* batch_hidden = context.Input<LoDTensor>("BatchHidden");
    auto* hidden = context.Input<LoDTensor>("Hidden");
    auto* hidden_grad =
        context.Input<LoDTensor>(framework::GradVarName("Hidden"));
    auto* input_grad =
        context.Output<LoDTensor>(framework::GradVarName("Input"));
    auto* h0_grad = context.Output<Tensor>(framework::GradVarName("H0"));
    auto* weight_grad =
        context.Output<Tensor>(framework::GradVarName("Weight"));
    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));

    auto gate_dims = batch_gate->dims();
    auto hidden_dims = hidden->dims();
    int frame_size = hidden_dims[1];

Q
QI JUN 已提交
145
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
G
guosheng 已提交
146 147 148 149 150
    LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad;
    batch_hidden_grad.mutable_data<T>(hidden_dims, context.GetPlace());
    batch_gate_grad.mutable_data<T>(gate_dims, context.GetPlace());
    batch_reset_hidden_prev_grad.mutable_data<T>(hidden_dims,
                                                 context.GetPlace());
Q
QI JUN 已提交
151 152
    math::SetConstant<DeviceContext, T> zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
153 154 155
    zero(dev_ctx, &batch_hidden_grad, static_cast<T>(0.0));
    zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0));
    zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));
G
guosheng 已提交
156

G
guosheng 已提交
157 158 159
    Tensor ordered_h0, ordered_h0_grad;
    const size_t* order = batch_gate->lod()[2].data();
    if (h0) {
Q
QI JUN 已提交
160 161
      ReorderInitState<DeviceContext, T>(dev_ctx, *h0, order, &ordered_h0,
                                         true);
G
guosheng 已提交
162 163 164
    }
    if (h0_grad) {
      ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
Q
QI JUN 已提交
165 166
      zero(context.template device_context<DeviceContext>(), &ordered_h0_grad,
           static_cast<T>(0.0));
G
guosheng 已提交
167 168
    }

G
guosheng 已提交
169 170
    bool is_reverse = context.Attr<bool>("is_reverse");
    batch_hidden_grad.set_lod(batch_hidden->lod());
171
    to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
G
guosheng 已提交
172 173

    math::hl_gru_value<T> gru_value;
G
guosheng 已提交
174 175
    gru_value.gate_weight = const_cast<T*>(weight_data);
    gru_value.state_weight =
G
guosheng 已提交
176 177 178 179
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);

    math::hl_gru_grad<T> gru_grad;
    if (weight_grad) {
G
guosheng 已提交
180
      gru_grad.gate_weight_grad =
G
guosheng 已提交
181
          weight_grad->mutable_data<T>(context.GetPlace());
182
      zero(dev_ctx, weight_grad, static_cast<T>(0.0));
G
guosheng 已提交
183
      gru_grad.state_weight_grad =
G
guosheng 已提交
184 185
          weight_grad->data<T>() + 2 * frame_size * frame_size;
    } else {
G
guosheng 已提交
186 187
      gru_grad.gate_weight_grad = nullptr;
      gru_grad.state_weight_grad = nullptr;
G
guosheng 已提交
188 189 190 191 192 193 194 195 196 197
    }

    auto batch_starts = batch_hidden_grad.lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
      int cur_batch_size = bend - bstart;

      Tensor gate_t = batch_gate->Slice(bstart, bend);
G
guosheng 已提交
198
      gru_value.gate_value = gate_t.data<T>();
G
guosheng 已提交
199
      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
G
guosheng 已提交
200
      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
G
guosheng 已提交
201 202

      Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend);
G
guosheng 已提交
203
      gru_grad.output_grad = hidden_grad_t.data<T>();
G
guosheng 已提交
204
      Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend);
G
guosheng 已提交
205
      gru_grad.gate_grad = gate_grad_t.data<T>();
G
guosheng 已提交
206 207
      Tensor reset_hidden_prev_grad_t =
          batch_reset_hidden_prev_grad.Slice(bstart, bend);
G
guosheng 已提交
208
      gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data<T>();
G
guosheng 已提交
209
      if (n == 0) {
G
guosheng 已提交
210 211
        gru_value.prev_out_value = h0 ? ordered_h0.data<T>() : nullptr;
        gru_grad.prev_out_grad =
G
guosheng 已提交
212
            h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
G
guosheng 已提交
213 214 215
      } else {
        int bstart_pre = static_cast<int>(batch_starts[n - 1]);
        Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
G
guosheng 已提交
216
        gru_value.prev_out_value = hidden_prev_t.data<T>();
G
guosheng 已提交
217
        Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
G
guosheng 已提交
218
        gru_grad.prev_out_grad = hidden_prev_grad_t.data<T>();
G
guosheng 已提交
219 220
      }

Q
QI JUN 已提交
221
      math::GRUUnitGradFunctor<DeviceContext, T>::compute(
222
          dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size,
G
guosheng 已提交
223 224 225 226 227
          math::ActiveType(context.Attr<std::string>("activation")),
          math::ActiveType(context.Attr<std::string>("gate_activation")));
    }
    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
228
      math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
G
guosheng 已提交
229
      batch_gate_grad.set_lod(batch_gate->lod());
230
      to_seq(dev_ctx, batch_gate_grad, *input_grad);
G
guosheng 已提交
231 232 233
    }
    if (bias_grad) {
      bias_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
234
      math::ColwiseSum<DeviceContext, T> col_sum;
235
      col_sum(dev_ctx, batch_gate_grad, bias_grad);
G
guosheng 已提交
236
    }
G
guosheng 已提交
237
    if (h0 && h0_grad) {
Q
QI JUN 已提交
238 239
      ReorderInitState<DeviceContext, T>(dev_ctx, ordered_h0_grad, order,
                                         h0_grad, false);
G
guosheng 已提交
240
    }
G
guosheng 已提交
241 242 243 244 245 246 247 248 249
  }

  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

}  // namespace operators
}  // namespace paddle