cudnn_lstm_op.cu.cc 13.1 KB
Newer Older
P
phlrain 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
L
liuhongyu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduozh 已提交
15
#include "paddle/fluid/framework/op_registry.h"
16
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
C
chengduozh 已提交
17
#include "paddle/fluid/operators/math/math_function.h"
18
#include "paddle/fluid/operators/utils.h"
G
GaoWei8 已提交
19
#include "paddle/fluid/platform/cudnn_desc.h"
G
GaoWei8 已提交
20
#include "paddle/fluid/platform/cudnn_helper.h"
L
liuhongyu 已提交
21 22 23 24 25 26 27

namespace paddle {
namespace operators {

// Local shorthands for the framework tensor types.
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
template <typename T>
void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
                  const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
                  const T *init_h_data, const T *init_c_data, const T *w_data,
                  T *out_data, T *last_h_data, T *last_c_data,
                  framework::Tensor *workspace_data,
                  const size_t &workspace_size) {
  if (!has_seq_length) {
    // for inference
    // This interface is used when the input/output is unpadded.
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference(
        handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
        rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
        rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
        rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
        workspace_data->data<uint8_t>(), workspace_size));
  } else {
#if CUDNN_VERSION >= 7201
    // for inference
    // This interface is used when the input/output is padded.
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
        handle, rnn->rnn_desc(), rnn->x_seq_desc(), x_data, rnn->init_h_desc(),
        init_h_data, rnn->init_c_desc(), init_c_data, rnn->weight_desc(),
        w_data, rnn->y_seq_desc(), out_data, rnn->last_h_desc(), last_h_data,
        rnn->last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
        nullptr, nullptr, nullptr, nullptr, workspace_data->data<uint8_t>(),
        workspace_size));
#else
    // CUDNN VERSION has to >=7.2.1
    PADDLE_THROW(platform::errors::Unavailable(
        "The padded input is supported by "
        "cudnnRNNForwardInferenceEx, but it only works when "
        "the version of cudnn is larger than 7.2.1"));
#endif
  }
}

C
chengduozh 已提交
65
template <typename T>
L
liuhongyu 已提交
66 67 68 69 70 71 72 73 74 75
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *x = ctx.Input<Tensor>("Input");
    const Tensor *init_h = ctx.Input<Tensor>("InitH");
    const Tensor *init_c = ctx.Input<Tensor>("InitC");

    auto w = ctx.Input<Tensor>("W");

    Tensor *out = ctx.Output<Tensor>("Out");
G
GaoWei8 已提交
76 77 78 79
    Tensor *last_h = ctx.Output<Tensor>("LastH");
    Tensor *last_c = ctx.Output<Tensor>("LastC");
    Tensor *reserve = ctx.Output<Tensor>("Reserve");
    Tensor *state_out = ctx.Output<Tensor>("StateOut");
L
liuhongyu 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95

    const T *x_data = x->data<T>();
    const T *init_h_data = init_h->data<T>();
    const T *init_c_data = init_c->data<T>();

    const T *w_data = w->data<T>();

    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
    T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    bool is_test = ctx.Attr<bool>("is_test");
G
GaoWei8 已提交
96
    int seed = ctx.Attr<int>("seed");
97 98 99 100 101 102 103

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }
L
liuhongyu 已提交
104 105 106 107

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

G
GaoWei8 已提交
108 109 110 111 112
    int seq_length = x->dims()[0];
    int batch_size = x->dims()[1];
    int input_size = x->dims()[2];
    int weight_numel = w->numel();
    bool state_initialized = state_out->IsInitialized() ? true : false;
G
GaoWei8 已提交
113

G
GaoWei8 已提交
114
    size_t workspace_size;
G
GaoWei8 已提交
115
    size_t reserve_size;
G
GaoWei8 已提交
116

117 118 119 120
    ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                      num_layers, dropout_prob, seed, weight_numel,
                      state_initialized, is_bidirec);
    rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
G
GaoWei8 已提交
121 122 123
                  &reserve_size, state_out);

    framework::Tensor workspace_data_;
124 125
    workspace_data_.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
G
GaoWei8 已提交
126 127 128

    auto *reserve_data = reserve->mutable_data<uint8_t>(
        {static_cast<int64_t>(reserve_size)}, ctx.GetPlace());
L
liuhongyu 已提交
129 130

    if (is_test) {
131 132 133
      LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
                      init_h_data, init_c_data, w_data, out_data, last_h_data,
                      last_c_data, &workspace_data_, workspace_size);
L
liuhongyu 已提交
134
    } else {
135
      if (!has_seq_length) {
G
GaoWei8 已提交
136 137 138
        // for train
        // This interface is used when the input/output is unpadded.
        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
139 140 141 142
            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
            rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
G
GaoWei8 已提交
143 144 145 146 147 148 149 150
            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
      } else {
#if CUDNN_VERSION >= 7201
        // for train
        // This interface is used when the input/output is padded.
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnRNNForwardTrainingEx(
151 152 153 154 155 156 157
                handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data,
                rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
                rnn.weight_desc(), w_data, rnn.y_seq_desc(), out_data,
                rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
                nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                nullptr, workspace_data_.data<uint8_t>(), workspace_size,
                reserve_data, reserve_size));
G
GaoWei8 已提交
158
#else
159 160 161 162
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardTrainingEx, but it only works when "
            "the version of cudnn is larger than 7.2.1"));
G
GaoWei8 已提交
163 164
#endif
      }
L
liuhongyu 已提交
165 166 167 168
    }
  }
};

C
chengduozh 已提交
169
template <typename T>
L
liuhongyu 已提交
170 171 172 173 174 175 176
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *input = ctx.Input<Tensor>("Input");
    auto *weight = ctx.Input<Tensor>("W");
    auto *init_h = ctx.Input<Tensor>("InitH");
    auto *init_c = ctx.Input<Tensor>("InitC");
G
GaoWei8 已提交
177 178 179
    auto *reserve = ctx.Input<Tensor>("Reserve");
    auto *state_out = ctx.Input<Tensor>("StateOut");

L
liuhongyu 已提交
180 181
    auto *out = ctx.Input<Tensor>("Out");
    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
G
GaoWei8 已提交
182 183
    auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
    auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));
L
liuhongyu 已提交
184 185 186 187 188 189 190 191 192 193 194 195 196

    auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
    auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
    auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    auto input_dims = input->dims();
    auto init_h_dims = init_h->dims();
    auto init_c_dims = init_c->dims();

G
GaoWei8 已提交
197 198 199 200 201 202 203
    auto *weight_data = weight->data<T>();
    auto *init_h_data = init_h->data<T>();
    auto *init_c_data = init_c->data<T>();
    auto *out_data = out->data<T>();
    auto *out_grad_data = out_grad->data<T>();
    auto *last_h_grad_data = last_h_grad->data<T>();
    auto *last_c_grad_data = last_c_grad->data<T>();
L
liuhongyu 已提交
204

G
GaoWei8 已提交
205 206 207
    math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
    weight_grad->mutable_data<T>(ctx.GetPlace());
    zero(dev_ctx, weight_grad, static_cast<T>(0.0));
L
liuhongyu 已提交
208

G
GaoWei8 已提交
209 210
    in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
    auto *in_grad_data = in_grad->data<T>();
L
liuhongyu 已提交
211

G
GaoWei8 已提交
212 213
    init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
    auto *init_h_grad_data = init_h_grad->data<T>();
L
liuhongyu 已提交
214

G
GaoWei8 已提交
215 216
    init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
    auto *init_c_grad_data = init_c_grad->data<T>();
L
liuhongyu 已提交
217

G
GaoWei8 已提交
218 219 220 221 222
    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    int seed = ctx.Attr<int>("seed");
223 224 225 226 227 228 229

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }
G
GaoWei8 已提交
230

G
GaoWei8 已提交
231 232 233 234
    int seq_length = input_dims[0];
    int batch_size = input->dims()[1];
    int input_size = input->dims()[2];
    int weight_numel = weight->numel();
G
GaoWei8 已提交
235

G
GaoWei8 已提交
236
    size_t workspace_size;
G
GaoWei8 已提交
237
    size_t reserve_size;
G
GaoWei8 已提交
238

239 240 241
    ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                      num_layers, dropout_prob, seed, weight_numel, true,
                      is_bidirec);
G
GaoWei8 已提交
242

243
    rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
G
GaoWei8 已提交
244 245 246
                  &reserve_size, const_cast<Tensor *>(state_out));

    framework::Tensor workspace_data_;
247 248
    workspace_data_.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
G
GaoWei8 已提交
249
    const uint8_t *reserve_data = reserve->data<uint8_t>();
L
liuhongyu 已提交
250

251
    if (!has_seq_length) {
G
GaoWei8 已提交
252 253
      // This interface is used when the input/output is unpadded.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData(
254 255 256 257 258 259 260
          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
G
GaoWei8 已提交
261 262

      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
263 264 265
          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
          workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
G
GaoWei8 已提交
266 267 268 269 270 271 272 273
          weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
          reserve_size));
    } else {
#if CUDNN_VERSION >= 7201
      // for train
      // This interface is used when the input/output is padded.
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
          handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
274 275 276 277 278
          out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
G
GaoWei8 已提交
279 280 281 282 283
          workspace_data_.data<uint8_t>(), workspace_size,
          const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
          handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
284 285 286 287
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
          out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
          rnn.weight_desc(), weight_grad->data<T>(),
          const_cast<uint8_t *>(reserve_data), reserve_size));
G
GaoWei8 已提交
288
#else
289 290 291 292
      PADDLE_THROW(platform::errors::Unavailable(
          "The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
          "cudnnRNNBackwardWeightsEx, but it only works when the version "
          "of cudnn is larger than 7.2.1"));
G
GaoWei8 已提交
293 294
#endif
    }
L
liuhongyu 已提交
295 296 297 298 299 300 301
  }
};

}  // namespace operators
}  // namespace paddle

// Shorthand for the operator namespace, used by the kernel registrations.
namespace ops = paddle::operators;
G
GaoWei8 已提交
302 303 304 305
// Register the forward kernel of cudnn_lstm for float and double.
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>,
                        ops::CudnnLSTMGPUKernel<double>);
// Register the backward kernel (cudnn_lstm_grad) for float and double.
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>,
                        ops::CudnnLSTMGPUGradKernel<double>);