/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

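// Gather (indexed_src == true) or scatter (indexed_src == false) the rows of
// `src` into `dst` according to `index_lod`, so that per-sequence initial
// states line up with the length-sorted batch order produced by
// math::LoDTensor2BatchFunctor.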
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
                             const framework::Tensor& src,
                             framework::Vector<size_t> index_lod,
                             framework::Tensor* dst, bool indexed_src) {
  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  row_shuffle(ctx, src, index_lod, dst, indexed_src);
}

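// Backward kernel of the GRU operator. It consumes the batched (time-major)
// intermediates saved by the forward pass and walks the time steps in
// reverse, accumulating gradients for the input projections, the recurrent
// weights, the bias, and the initial hidden state H0.
//
// As a rough sketch of the recurrence being differentiated (the exact gate
// conventions are defined in math/detail/gru_kernel.h): an update gate and a
// reset gate are computed from the input projections and h_{t-1}; the reset
// gate masks h_{t-1} inside the candidate state; and h_t interpolates
// between h_{t-1} and the candidate according to the update gate.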
template <typename DeviceContext, typename T>
class GRUGradKernel : public framework::OpKernel<T> {
 public:
  void BatchCompute(const framework::ExecutionContext& context) const {
    auto* h0 = context.Input<Tensor>("H0");
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* batch_gate = context.Input<LoDTensor>("BatchGate");
    auto* batch_reset_hidden_prev =
        context.Input<LoDTensor>("BatchResetHiddenPrev");
    auto* batch_hidden = context.Input<LoDTensor>("BatchHidden");
    auto* hidden = context.Input<LoDTensor>("Hidden");
    auto* hidden_grad =
        context.Input<LoDTensor>(framework::GradVarName("Hidden"));
    auto* input_grad =
        context.Output<LoDTensor>(framework::GradVarName("Input"));
    auto* h0_grad = context.Output<Tensor>(framework::GradVarName("H0"));
    auto* weight_grad =
        context.Output<Tensor>(framework::GradVarName("Weight"));
    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));

    auto gate_dims = batch_gate->dims();
    auto hidden_dims = hidden->dims();
    int frame_size = hidden_dims[1];

    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad;
    batch_hidden_grad.mutable_data<T>(hidden_dims, context.GetPlace());
    batch_gate_grad.mutable_data<T>(gate_dims, context.GetPlace());
    batch_reset_hidden_prev_grad.mutable_data<T>(hidden_dims,
                                                 context.GetPlace());
    math::SetConstant<DeviceContext, T> zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    zero(dev_ctx, &batch_hidden_grad, static_cast<T>(0.0));
    zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0));
    zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));

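    // lod()[2] of the batched gates records how sequences were reordered when
    // batching (longest first); use it to gather H0 into batch order.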
    Tensor ordered_h0, ordered_h0_grad;

    framework::Vector<size_t> order(batch_gate->lod()[2]);

    if (h0) {
      ReorderInitState<DeviceContext, T>(dev_ctx, *h0, order, &ordered_h0,
                                         true);
    }
    if (h0_grad) {
      ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
      zero(context.template device_context<DeviceContext>(), &ordered_h0_grad,
           static_cast<T>(0.0));
    }

    bool is_reverse = context.Attr<bool>("is_reverse");
    batch_hidden_grad.set_lod(batch_hidden->lod());
    to_batch(dev_ctx, *hidden_grad, &batch_hidden_grad, false, is_reverse);

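    // Weight stores the update/reset gate weights in its first
    // 2 * frame_size * frame_size elements; the remaining
    // frame_size * frame_size elements hold the candidate-state weights.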
    math::GRUMetaValue<T> gru_value;
    gru_value.gate_weight = const_cast<T*>(weight_data);
    gru_value.state_weight =
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);

    math::GRUMetaGrad<T> gru_grad;
    if (weight_grad) {
      gru_grad.gate_weight_grad =
          weight_grad->mutable_data<T>(context.GetPlace());
      zero(dev_ctx, weight_grad, static_cast<T>(0.0));
      gru_grad.state_weight_grad =
          weight_grad->data<T>() + 2 * frame_size * frame_size;
    } else {
      gru_grad.gate_weight_grad = nullptr;
      gru_grad.state_weight_grad = nullptr;
    }

    auto batch_starts = batch_hidden_grad.lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    auto active_node = math::detail::GetActivationType(
        context.Attr<std::string>("activation"));
    auto active_gate = math::detail::GetActivationType(
        context.Attr<std::string>("gate_activation"));
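    // Walk the batched time steps from last to first; each step covers only
    // the sequences that are still active at that step.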
    for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);
      int cur_batch_size = bend - bstart;

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      gru_value.gate_value = gate_t.data<T>();
      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();

      Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend);
      gru_grad.output_grad = hidden_grad_t.data<T>();
      Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend);
      gru_grad.gate_grad = gate_grad_t.data<T>();
      Tensor reset_hidden_prev_grad_t =
          batch_reset_hidden_prev_grad.Slice(bstart, bend);
      gru_grad.reset_output_grad = reset_hidden_prev_grad_t.data<T>();
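      // At the first time step the previous hidden state is the (reordered)
      // H0 if provided; at later steps it is the preceding batch slice.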
      if (n == 0) {
        gru_value.prev_out_value = h0 ? ordered_h0.data<T>() : nullptr;
        gru_grad.prev_out_grad =
            h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
      } else {
        int bstart_pre = static_cast<int>(batch_starts[n - 1]);
        Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
        gru_value.prev_out_value = hidden_prev_t.data<T>();
        Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
        gru_grad.prev_out_grad = hidden_prev_grad_t.data<T>();
      }

      math::GRUUnitGradFunctor<DeviceContext, T>::compute(
          dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
          active_gate);
    }
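    // Scatter the batched gate gradients back into sequence (LoDTensor) order
    // to form the gradient w.r.t. the input projections.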
    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
      math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
      batch_gate_grad.set_lod(batch_gate->lod());
      to_seq(dev_ctx, batch_gate_grad, input_grad);
    }
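    // The bias feeds every time step, so its gradient is the column-wise sum
    // of the gate gradients over all batched steps.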
    if (bias_grad) {
      bias_grad->mutable_data<T>(context.GetPlace());
      math::ColwiseSum<DeviceContext, T> col_sum;
      col_sum(dev_ctx, batch_gate_grad, bias_grad);
    }
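    // Scatter the batch-ordered H0 gradient back to the original sequence
    // order.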
    if (h0 && h0_grad) {
      ReorderInitState<DeviceContext, T>(dev_ctx, ordered_h0_grad, order,
                                         h0_grad, false);
    }
  }

  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

}  // namespace operators
}  // namespace paddle