/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

/* Acknowledgement: the following code is strongly inspired by
https://github.com/caffe2/caffe2/blob/master/caffe2/operators/lstm_unit_op_gpu.cu
*/

#include "paddle/fluid/operators/lstm_unit_op.h"
Y
Yi Wang 已提交
20 21
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cross_entropy_op.h"
22
#include "paddle/phi/core/hostdevice.h"
Z
zchen0211 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37

namespace paddle {
namespace operators {

// Device-side logistic sigmoid: returns 1 / (1 + exp(-x)).
template <typename Dtype>
__device__ Dtype cuda_sigmoid(const Dtype x) {
  // Keep the denominator in whatever type the expression naturally has so
  // the arithmetic matches a single-expression evaluation exactly.
  const auto denominator = Dtype(1) + exp(-x);
  return Dtype(1) / denominator;
}

// Device-side hyperbolic tangent of x.
//
// Delegates to the math library's overloaded tanh() instead of evaluating
// (1 - exp(-2x)) / (1 + exp(-2x)) by hand. In the hand-written quotient
// exp(-2x) overflows for large negative x, so the result becomes -inf or
// NaN rather than saturating at -1, and the `2.` literal drags float inputs
// through double arithmetic.
template <typename Dtype>
__device__ Dtype cuda_tanh(const Dtype x) {
  return tanh(x);
}

// Forward pass of a single LSTM step, computed element-wise.
//
// Each loop index addresses one element (n, d) with n = index / dim and
// d = index % dim.
//
// Args:
//   nthreads:    number of elements to process (the launcher in this file
//                passes C.dims()[0] * C.dims()[1]).
//   dim:         row length of C_prev / C / H.
//   C_prev:      previous cell state, read at the flat index.
//   X:           gate pre-activations; row n holds 4 * dim values laid out
//                as four consecutive chunks of length dim, read here as
//                input (i), forget (f), output (o) and candidate (g).
//   C:           output cell state, written at the flat index.
//   H:           output hidden state, written at the flat index.
//   forget_bias: added to the forget-gate pre-activation before the sigmoid.
//
// CUDA_KERNEL_LOOP is defined outside this file; it is presumably a loop of
// `index` over [0, nthreads) distributed across threads -- confirm against
// its definition.
template <typename T>
__global__ void LSTMUnitKernel(const int nthreads,
                               const int dim,
                               const T* C_prev,
                               const T* X,
                               T* C,
                               T* H,
                               const T forget_bias) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int n = index / dim;
    const int d = index % dim;

    // Widen before multiplying: 4 * dim * n is roughly 4 * index and would
    // overflow int once the cell tensor exceeds 2^29 elements.
    const T* X_offset = X + static_cast<size_t>(4) * dim * n;
    const T i = cuda_sigmoid(X_offset[d]);
    const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
    const T o = cuda_sigmoid(X_offset[2 * dim + d]);
    const T g = cuda_tanh(X_offset[3 * dim + d]);
    const T c_prev = C_prev[index];
    // c = f * c_prev + i * g ; h = o * tanh(c)
    const T c = f * c_prev + i * g;
    C[index] = c;
    const T tanh_c = cuda_tanh(c);
    H[index] = o * tanh_c;
  }
}

// Backward pass of a single LSTM step, computed element-wise.
//
// Recomputes the gate activations from X (same layout as in the forward
// kernel: four consecutive chunks of length dim per row, read as i, f, o, g)
// and writes the gradients with respect to C_prev and X.
//
// Args:
//   nthreads:    number of elements to process (the launcher in this file
//                passes C.dims()[0] * C.dims()[1]).
//   dim:         row length of the cell-state tensors.
//   C_prev:      previous cell state from the forward pass.
//   X:           gate pre-activations from the forward pass.
//   C:           cell state produced by the forward pass.
//   C_diff:      incoming gradient w.r.t. C.
//   H_diff:      incoming gradient w.r.t. H.
//   C_prev_diff: output gradient w.r.t. C_prev, written at the flat index.
//   X_diff:      output gradient w.r.t. X, written with the same 4 * dim
//                row layout as X.
//   forget_bias: added to the forget-gate pre-activation before the sigmoid.
//
// CUDA_KERNEL_LOOP is defined outside this file; it is presumably a loop of
// `index` over [0, nthreads) distributed across threads -- confirm against
// its definition.
template <typename T>
__global__ void LSTMUnitGradientKernel(const int nthreads,
                                       const int dim,
                                       const T* C_prev,
                                       const T* X,
                                       const T* C,
                                       const T* C_diff,
                                       const T* H_diff,
                                       T* C_prev_diff,
                                       T* X_diff,
                                       const T forget_bias) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int n = index / dim;
    const int d = index % dim;
    // Widen before multiplying: 4 * dim * n is roughly 4 * index and would
    // overflow int once the cell tensor exceeds 2^29 elements.
    const size_t row_offset = static_cast<size_t>(4) * dim * n;
    const T* X_offset = X + row_offset;
    T* c_prev_diff = C_prev_diff + index;
    T* X_diff_offset = X_diff + row_offset;
    T* i_diff = X_diff_offset + d;
    T* f_diff = X_diff_offset + 1 * dim + d;
    T* o_diff = X_diff_offset + 2 * dim + d;
    T* g_diff = X_diff_offset + 3 * dim + d;

    const T i = cuda_sigmoid(X_offset[d]);
    const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
    const T o = cuda_sigmoid(X_offset[2 * dim + d]);
    const T g = cuda_tanh(X_offset[3 * dim + d]);
    const T c_prev = C_prev[index];
    const T c = C[index];
    const T tanh_c = cuda_tanh(c);
    // Total gradient reaching c: the direct term plus the path through
    // h = o * tanh(c), using d tanh(c)/dc = 1 - tanh(c)^2.
    const T c_term_diff =
        C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c);
    *c_prev_diff = c_term_diff * f;
    // Sigmoid gates contribute s * (1 - s); the tanh gate contributes
    // 1 - g^2.
    *i_diff = c_term_diff * g * i * (1 - i);
    *f_diff = c_term_diff * c_prev * f * (1 - f);
    *o_diff = H_diff[index] * tanh_c * o * (1 - o);
    *g_diff = c_term_diff * i * (1 - g * g);
  }
}

101
// GPU forward kernel class for the `lstm_unit` operator.
//
// Reads inputs "X" and "C_prev" and the float attribute "forget_bias",
// allocates outputs "C" and "H" on the execution place, and launches
// LSTMUnitKernel over C.dims()[0] * C.dims()[1] elements.
template <typename T>
class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()),
        true,
        paddle::platform::errors::PreconditionNotMet("It must use CUDAPlace."));

    auto* x_tensor = ctx.Input<phi::DenseTensor>("X");
    auto* c_prev_tensor = ctx.Input<phi::DenseTensor>("C_prev");
    auto* c_tensor = ctx.Output<phi::DenseTensor>("C");
    auto* h_tensor = ctx.Output<phi::DenseTensor>("H");

    auto forget_bias = static_cast<T>(ctx.Attr<float>("forget_bias"));

    int b_size = c_tensor->dims()[0];
    int D = c_tensor->dims()[1];

    const T* X = x_tensor->data<T>();
    const T* C_prev = c_prev_tensor->data<T>();

    T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
    T* H = h_tensor->mutable_data<T>(ctx.GetPlace());

    int block = 512;
    int n = b_size * D;
    // An empty tensor would give grid == 0, which the CUDA runtime rejects
    // as an invalid launch configuration; there is nothing to compute.
    if (n == 0) {
      return;
    }
    // Ceil-divide so a partial final block is still launched.
    int grid = (n + block - 1) / block;

    // NOTE(review): launched on the default stream -- confirm whether this
    // should use the stream of the op's device context instead.
    LSTMUnitKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, forget_bias);
  }
};

134
// GPU backward kernel class for the `lstm_unit_grad` operator.
//
// Reads the forward tensors "X", "C_prev" and "C", the incoming gradients
// of "H" and "C", and the float attribute "forget_bias"; allocates the
// gradients of "X" and "C_prev" on the execution place and launches
// LSTMUnitGradientKernel over C.dims()[0] * C.dims()[1] elements.
template <typename T>
class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()),
        true,
        paddle::platform::errors::PreconditionNotMet("It must use CUDAPlace."));

    // The forward output "H" is not needed here: the kernel recomputes
    // everything it uses from X, C_prev and C.
    auto x_tensor = ctx.Input<phi::DenseTensor>("X");
    auto c_prev_tensor = ctx.Input<phi::DenseTensor>("C_prev");
    auto c_tensor = ctx.Input<phi::DenseTensor>("C");

    auto hdiff_tensor =
        ctx.Input<phi::DenseTensor>(framework::GradVarName("H"));
    auto cdiff_tensor =
        ctx.Input<phi::DenseTensor>(framework::GradVarName("C"));

    auto xdiff_tensor =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
    auto c_prev_diff_tensor =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("C_prev"));

    auto* X = x_tensor->data<T>();
    auto* C_prev = c_prev_tensor->data<T>();
    auto* C = c_tensor->data<T>();

    auto* H_diff = hdiff_tensor->data<T>();
    auto* C_diff = cdiff_tensor->data<T>();

    auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
    auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());

    int N = c_tensor->dims()[0];
    int D = c_tensor->dims()[1];

    auto forget_bias = static_cast<T>(ctx.Attr<float>("forget_bias"));

    int block = 512;
    int n = N * D;
    // An empty tensor would give grid == 0, which the CUDA runtime rejects
    // as an invalid launch configuration; there is nothing to compute.
    if (n == 0) {
      return;
    }
    // Ceil-divide so a partial final block is still launched.
    int grid = (n + block - 1) / block;

    // NOTE(review): launched on the default stream -- confirm whether this
    // should use the stream of the op's device context instead.
    LSTMUnitGradientKernel<T><<<grid, block>>>(
        n, D, C_prev, X, C, C_diff, H_diff, C_prev_diff, X_diff, forget_bias);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the forward and backward CUDA kernels for float and double.
REGISTER_OP_CUDA_KERNEL(lstm_unit,
                        ops::LstmUnitOpCUDAKernel<float>,
                        ops::LstmUnitOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(lstm_unit_grad,
                        ops::LstmUnitGradOpCUDAKernel<float>,
                        ops::LstmUnitGradOpCUDAKernel<double>);