/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/utils.h"

// Forward declarations of the platform types referenced by name in this file.
namespace paddle {
namespace platform {
struct CUDAPlace;
class CUDADeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

// Short local names for the framework tensor types.
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Runs one cuDNN LSTM forward pass through the inference entry points.
//
// The cuDNN call is selected by `has_seq_length`:
//   * false: cudnnRNNForwardInference with the per-time-step descriptors
//     rnn->x_descs() / rnn->y_descs(); `seq_length` is the number of steps.
//   * true:  cudnnRNNForwardInferenceEx with the sequence-data descriptors
//     rnn->x_seq_desc() / rnn->y_seq_desc(); `seq_length` is not passed to
//     cuDNN on this path. This branch is only compiled in when
//     CUDNN_VERSION >= 7201; otherwise an Unavailable error is thrown.
//
// `rnn` must already hold valid descriptors (the kernel in this file obtains
// them from ScopedRNNBase::Create before calling here). `workspace_data` is
// the scratch buffer handed to cuDNN together with `workspace_size` bytes.
// Unlike the training calls in CudnnLSTMGPUKernel, no reserve-space buffer
// is passed. Every cuDNN status is checked with PADDLE_ENFORCE_CUDA_SUCCESS.
//
// NOTE: the historical spelling "Inferece" is kept because callers use it.
template <typename T>
void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
                  const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
                  const T *init_h_data, const T *init_c_data, const T *w_data,
                  T *out_data, T *last_h_data, T *last_c_data,
                  framework::Tensor *workspace_data,
                  const size_t &workspace_size) {
  if (has_seq_length) {
#if CUDNN_VERSION >= 7201
    // Padded input/output. The eight nullptr arguments fill the
    // kDesc/keys, cDesc/cAttn, iDesc/iAttn and qDesc/queries slots, which
    // the cuDNN API reference documents as reserved.
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
        handle, rnn->rnn_desc(), rnn->x_seq_desc(), x_data, rnn->init_h_desc(),
        init_h_data, rnn->init_c_desc(), init_c_data, rnn->weight_desc(),
        w_data, rnn->y_seq_desc(), out_data, rnn->last_h_desc(), last_h_data,
        rnn->last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
        nullptr, nullptr, nullptr, nullptr, workspace_data->data<uint8_t>(),
        workspace_size));
#else
    // The padded API is unavailable in cuDNN builds older than 7.2.1.
    PADDLE_THROW(platform::errors::Unavailable(
        "The padded input is supported by "
        "cudnnRNNForwardInferenceEx, but it only works when "
        "the version of cudnn is larger than 7.2.1"));
#endif
    return;
  }

  // Unpadded input/output: one tensor descriptor per time step.
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
      handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
      rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
      rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
      rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
      workspace_data->data<uint8_t>(), workspace_size));
}

template <typename T>
L
liuhongyu 已提交
71 72 73 74 75 76 77 78 79 80
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *x = ctx.Input<Tensor>("Input");
    const Tensor *init_h = ctx.Input<Tensor>("InitH");
    const Tensor *init_c = ctx.Input<Tensor>("InitC");

    auto w = ctx.Input<Tensor>("W");

    Tensor *out = ctx.Output<Tensor>("Out");
G
GaoWei8 已提交
81 82 83 84
    Tensor *last_h = ctx.Output<Tensor>("LastH");
    Tensor *last_c = ctx.Output<Tensor>("LastC");
    Tensor *reserve = ctx.Output<Tensor>("Reserve");
    Tensor *state_out = ctx.Output<Tensor>("StateOut");
L
liuhongyu 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

    const T *x_data = x->data<T>();
    const T *init_h_data = init_h->data<T>();
    const T *init_c_data = init_c->data<T>();

    const T *w_data = w->data<T>();

    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
    T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    bool is_test = ctx.Attr<bool>("is_test");
G
GaoWei8 已提交
101
    int seed = ctx.Attr<int>("seed");
102 103 104 105 106 107 108

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }
L
liuhongyu 已提交
109 110 111 112

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

G
GaoWei8 已提交
113 114 115 116 117
    int seq_length = x->dims()[0];
    int batch_size = x->dims()[1];
    int input_size = x->dims()[2];
    int weight_numel = w->numel();
    bool state_initialized = state_out->IsInitialized() ? true : false;
G
GaoWei8 已提交
118

G
GaoWei8 已提交
119
    size_t workspace_size;
G
GaoWei8 已提交
120
    size_t reserve_size;
G
GaoWei8 已提交
121

122 123 124 125
    ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                      num_layers, dropout_prob, seed, weight_numel,
                      state_initialized, is_bidirec);
    rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
G
GaoWei8 已提交
126 127 128
                  &reserve_size, state_out);

    framework::Tensor workspace_data_;
129 130
    workspace_data_.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
G
GaoWei8 已提交
131 132 133

    auto *reserve_data = reserve->mutable_data<uint8_t>(
        {static_cast<int64_t>(reserve_size)}, ctx.GetPlace());
L
liuhongyu 已提交
134 135

    if (is_test) {
136 137 138
      LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
                      init_h_data, init_c_data, w_data, out_data, last_h_data,
                      last_c_data, &workspace_data_, workspace_size);
L
liuhongyu 已提交
139
    } else {
140
      if (!has_seq_length) {
G
GaoWei8 已提交
141 142 143
        // for train
        // This interface is used when the input/output is unpadded.
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
144 145 146 147
            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
            rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
G
GaoWei8 已提交
148 149 150 151 152 153 154 155
            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
      } else {
#if CUDNN_VERSION >= 7201
        // for train
        // This interface is used when the input/output is padded.
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnRNNForwardTrainingEx(
156 157 158 159 160 161 162
                handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data,
                rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
                rnn.weight_desc(), w_data, rnn.y_seq_desc(), out_data,
                rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
                nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                nullptr, workspace_data_.data<uint8_t>(), workspace_size,
                reserve_data, reserve_size));
G
GaoWei8 已提交
163
#else
164 165 166 167
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardTrainingEx, but it only works when "
            "the version of cudnn is larger than 7.2.1"));
G
GaoWei8 已提交
168 169
#endif
      }
L
liuhongyu 已提交
170 171 172 173
    }
  }
};

template <typename T>
L
liuhongyu 已提交
175 176 177 178 179 180 181
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *input = ctx.Input<Tensor>("Input");
    auto *weight = ctx.Input<Tensor>("W");
    auto *init_h = ctx.Input<Tensor>("InitH");
    auto *init_c = ctx.Input<Tensor>("InitC");
G
GaoWei8 已提交
182 183 184
    auto *reserve = ctx.Input<Tensor>("Reserve");
    auto *state_out = ctx.Input<Tensor>("StateOut");

L
liuhongyu 已提交
185 186
    auto *out = ctx.Input<Tensor>("Out");
    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
G
GaoWei8 已提交
187 188
    auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
    auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));
L
liuhongyu 已提交
189 190 191 192 193 194 195 196 197 198 199 200 201

    auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
    auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
    auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    auto input_dims = input->dims();
    auto init_h_dims = init_h->dims();
    auto init_c_dims = init_c->dims();

G
GaoWei8 已提交
202 203 204 205 206 207 208
    auto *weight_data = weight->data<T>();
    auto *init_h_data = init_h->data<T>();
    auto *init_c_data = init_c->data<T>();
    auto *out_data = out->data<T>();
    auto *out_grad_data = out_grad->data<T>();
    auto *last_h_grad_data = last_h_grad->data<T>();
    auto *last_c_grad_data = last_c_grad->data<T>();
L
liuhongyu 已提交
209

G
GaoWei8 已提交
210 211 212
    math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
    weight_grad->mutable_data<T>(ctx.GetPlace());
    zero(dev_ctx, weight_grad, static_cast<T>(0.0));
L
liuhongyu 已提交
213

G
GaoWei8 已提交
214 215
    in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
    auto *in_grad_data = in_grad->data<T>();
L
liuhongyu 已提交
216

G
GaoWei8 已提交
217 218
    init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
    auto *init_h_grad_data = init_h_grad->data<T>();
L
liuhongyu 已提交
219

G
GaoWei8 已提交
220 221
    init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
    auto *init_c_grad_data = init_c_grad->data<T>();
L
liuhongyu 已提交
222

G
GaoWei8 已提交
223 224 225 226 227
    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    int seed = ctx.Attr<int>("seed");
228 229 230 231 232 233 234

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }
G
GaoWei8 已提交
235

G
GaoWei8 已提交
236 237 238 239
    int seq_length = input_dims[0];
    int batch_size = input->dims()[1];
    int input_size = input->dims()[2];
    int weight_numel = weight->numel();
G
GaoWei8 已提交
240

G
GaoWei8 已提交
241
    size_t workspace_size;
G
GaoWei8 已提交
242
    size_t reserve_size;
G
GaoWei8 已提交
243

244 245 246
    ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                      num_layers, dropout_prob, seed, weight_numel, true,
                      is_bidirec);
G
GaoWei8 已提交
247

248
    rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
G
GaoWei8 已提交
249 250 251
                  &reserve_size, const_cast<Tensor *>(state_out));

    framework::Tensor workspace_data_;
252 253
    workspace_data_.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
G
GaoWei8 已提交
254
    const uint8_t *reserve_data = reserve->data<uint8_t>();
L
liuhongyu 已提交
255

256
    if (!has_seq_length) {
G
GaoWei8 已提交
257 258
      // This interface is used when the input/output is unpadded.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
259 260 261 262 263 264 265
          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
G
GaoWei8 已提交
266 267

      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
268 269 270
          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
          workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
G
GaoWei8 已提交
271 272 273 274 275 276 277 278
          weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
          reserve_size));
    } else {
#if CUDNN_VERSION >= 7201
      // for train
      // This interface is used when the input/output is padded.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
          handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
279 280 281 282 283
          out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
G
GaoWei8 已提交
284 285 286 287 288
          workspace_data_.data<uint8_t>(), workspace_size,
          const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
          handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
289 290 291 292
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
          out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
          rnn.weight_desc(), weight_grad->data<T>(),
          const_cast<uint8_t *>(reserve_data), reserve_size));
G
GaoWei8 已提交
293
#else
294 295 296 297
      PADDLE_THROW(platform::errors::Unavailable(
          "The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
          "cudnnRNNBackwardWeightsEx, but it only works when the version "
          "of cudnn is larger than 7.2.1"));
G
GaoWei8 已提交
298 299
#endif
    }
L
liuhongyu 已提交
300 301 302 303 304 305 306
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Registers the float and double instantiations of the forward
// (cudnn_lstm) and backward (cudnn_lstm_grad) kernels defined above.
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>,
                        ops::CudnnLSTMGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>,
                        ops::CudnnLSTMGPUGradKernel<double>);