/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/miopen_lstm_cache.h"
#endif

namespace paddle {
namespace platform {
class CUDADeviceContext;

}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

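// Returns true when the tensors in weight_list occupy one contiguous block of
// device memory, i.e. each tensor's data ends exactly where the next begins.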
template <typename T, typename Type>
bool is_continuous(const Type &weight_list) {
  bool continuous = true;
  for (size_t i = 0; i < weight_list.size() - 1; ++i) {
    auto *in_data = weight_list[i]->template data<T>();
    auto *in_after_data = weight_list[i + 1]->template data<T>();
    auto in_size = weight_list[i]->numel();
    bool temp = in_data + in_size == in_after_data;
    continuous = continuous && temp;
  }
  return continuous;
}

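// Returns the total number of elements across all tensors in weight_list.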
int size_sum(const std::vector<const Tensor *> &weight_list) {
  int size = 0;
  for (size_t i = 0; i < weight_list.size(); ++i) {
    auto in_size = weight_list[i]->numel();
    size += in_size;
  }
  return size;
}

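// Packs every tensor in weight_list back to back into the flat `weight`
// tensor, copying on the given stream.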
template <typename T>
void weight_to_tensor(const platform::Place &place, gpuStream_t stream,
                      const std::vector<const Tensor *> &weight_list,
                      Tensor *weight) {
  auto weight_data = weight->data<T>();
  int weight_offset = 0;
  for (size_t i = 0; i < weight_list.size(); ++i) {
    const T *in_data = weight_list[i]->data<T>();
    auto in_size = weight_list[i]->numel();

    memory::Copy(weight->place(), weight_data + weight_offset,
                 weight_list[i]->place(), in_data, in_size * sizeof(T), stream);
    weight_offset += in_size;
  }
}

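// Scatters slices of the flat `weight` tensor back into the per-parameter
// tensors in `weight_grad`, using the same packing order as weight_to_tensor.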
template <typename T>
void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream,
                           std::vector<Tensor *> *weight_grad,
                           const std::vector<const Tensor *> &weight_input,
                           const Tensor *weight) {
  int weight_offset = 0;
  auto *weight_data = weight->data<T>();
  for (size_t i = 0; i < weight_input.size(); ++i) {
    auto in_size = weight_input[i]->numel();
    T *weight_grad_data = (*weight_grad)[i]->mutable_data<T>(place);
    const T *src = weight_data + weight_offset;

    memory::Copy((*weight_grad)[i]->place(), weight_grad_data, weight->place(),
                 src, in_size * sizeof(T), stream);
    weight_offset += in_size;
  }
}

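// Inference-only forward pass. The unpadded cuDNN/MIOpen interface is used
// when no sequence lengths are given; the padded interface requires
// cuDNN >= 7.2.1.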
template <typename T>
#ifdef PADDLE_WITH_HIP
void LSTMInferece(const bool &has_seq_length, const miopenHandle_t &handle,
#else
void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
#endif
                  const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
                  const T *init_h_data, const T *init_c_data, const T *w_data,
                  T *out_data, T *last_h_data, T *last_c_data,
                  framework::Tensor *workspace_data,
                  const size_t &workspace_size) {
  if (!has_seq_length) {
// for inference
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardInference(
        handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
        rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
        rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
        rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
        workspace_data->data<uint8_t>(), workspace_size));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInference(
        handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
        rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
        rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
        rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
        workspace_data->data<uint8_t>(), workspace_size));
#endif
  } else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
    // for inference
    // This interface is used when the input/output is padded.
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
        handle, rnn->rnn_desc(), rnn->x_seq_desc(), x_data, rnn->init_h_desc(),
        init_h_data, rnn->init_c_desc(), init_c_data, rnn->weight_desc(),
        w_data, rnn->y_seq_desc(), out_data, rnn->last_h_desc(), last_h_data,
        rnn->last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
        nullptr, nullptr, nullptr, nullptr, workspace_data->data<uint8_t>(),
        workspace_size));
#else
    // Requires cuDNN >= 7.2.1
    PADDLE_THROW(platform::errors::Unavailable(
        "The padded input is supported by "
        "cudnnRNNForwardInferenceEx, but it only works when "
        "the cuDNN version is 7.2.1 or higher."));
#endif
  }
}

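// Forward kernel: flattens the weights if necessary, builds the RNN
// descriptors, and dispatches to the cuDNN/MIOpen inference or training API.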
template <typename T>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *x = ctx.Input<Tensor>("Input");
    const Tensor *init_h = ctx.Input<Tensor>("InitH");
    const Tensor *init_c = ctx.Input<Tensor>("InitC");

    Tensor *out = ctx.Output<Tensor>("Out");
    Tensor *last_h = ctx.Output<Tensor>("LastH");
    Tensor *last_c = ctx.Output<Tensor>("LastC");
    Tensor *reserve = ctx.Output<Tensor>("Reserve");
    Tensor *state_out = ctx.Output<Tensor>("StateOut");

    const T *x_data = x->data<T>();
    const T *init_h_data = init_h->data<T>();
    const T *init_c_data = init_c->data<T>();

    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
    T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    bool is_test = ctx.Attr<bool>("is_test");
    int seed = ctx.Attr<int>("seed");

    if (!is_test) {
      if (seed == 0) {
        // If no seed is specified, use the global Generator to generate one.
        int device_id = ctx.GetPlace().GetDeviceId();
        auto gen_cuda = paddle::framework::DefaultCUDAGenerator(device_id);
        seed = static_cast<int>(gen_cuda->Random64());
      }
      // otherwise use the seed specified by `ctx.Attr<int>("seed")`
    }

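    // An optional SequenceLength input switches the op to the padded
    // (variable-length) interface; otherwise the unpadded interface is used.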
    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    int seq_length = x->dims()[0];
    int batch_size = x->dims()[1];
    int input_size = x->dims()[2];
    bool state_initialized = state_out->IsInitialized();

    size_t workspace_size;
    size_t reserve_size;
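    // The RNN API takes a single flat weight buffer. In inference mode a
    // pre-flattened "W" input is used when available; otherwise the WeightList
    // tensors are packed into weight_whole (or used in place if their memory
    // is already contiguous).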
    Tensor weight_whole;
    T *w_data = nullptr;
    int weight_numel;
    bool w_initialized = false;
    auto place = ctx.GetPlace();
    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
                      ctx.device_context())
                      .stream();
    if (is_test && ctx.HasInput("W")) {
      auto *W = ctx.Input<Tensor>("W");
      w_initialized = W->IsInitialized();
      weight_numel = W->numel();
    }
    if (!w_initialized) {
      auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
      bool continuous =
          is_continuous<T, std::vector<const Tensor *>>(weight_list);
      weight_numel = size_sum(weight_list);

      if (!continuous) {
        LOG_FIRST_N(WARNING, 2)
            << "If the memory space of the Input WeightList is not continuous, "
               "a less efficient calculation path will be used. Please call "
               "flatten_parameters() to make the input memory continuous.";
        weight_whole.mutable_data<T>({weight_numel}, place);
        weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
        w_data = weight_whole.data<T>();
        if (is_test) {  // maybe also reset small weights' ptr for training
          int offset = 0;
          for (size_t i = 0; i < weight_list.size(); ++i) {
            size_t len = weight_list[i]->numel();
            auto dim = weight_list[i]->dims();
            const_cast<Tensor *>(weight_list[i])
                ->ShareDataWith(
                    weight_whole.Slice(static_cast<int64_t>(offset),
                                       static_cast<int64_t>(offset + len)))
                .Resize(dim);
            offset += len;
          }
        }
      } else {
        w_data = const_cast<T *>(weight_list[0]->data<T>());
      }
    } else {
      auto *W = ctx.Input<Tensor>("W");
      w_data = const_cast<T *>(W->data<T>());
    }

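    // Build the RNN descriptors; Create() also reports the workspace and
    // reserve-space sizes needed by the calls below.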
    ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                      num_layers, dropout_prob, seed, weight_numel,
                      state_initialized, is_bidirec);
    rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
                  &reserve_size, state_out);

    framework::Tensor workspace_data_;
    workspace_data_.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());

    auto *reserve_data = reserve->mutable_data<uint8_t>(
        {static_cast<int64_t>(reserve_size)}, ctx.GetPlace());

    if (is_test) {
      LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
                      init_h_data, init_c_data, w_data, out_data, last_h_data,
                      last_c_data, &workspace_data_, workspace_size);
    } else {
      if (!has_seq_length) {
// for training
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardTraining(
            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
            rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
            rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
#endif
      } else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
        // for training
        // This interface is used when the input/output is padded.
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTrainingEx(
            handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.init_h_desc(),
            init_h_data, rnn.init_c_desc(), init_c_data, rnn.weight_desc(),
            w_data, rnn.y_seq_desc(), out_data, rnn.last_h_desc(), last_h_data,
            rnn.last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
            nullptr, nullptr, nullptr, nullptr, workspace_data_.data<uint8_t>(),
            workspace_size, reserve_data, reserve_size));
#else
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardTrainingEx, but it only works when "
            "the cuDNN version is 7.2.1 or higher."));
#endif
      }
    }
  }
};

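// Backward kernel: recreates the RNN descriptors of the forward pass and
// computes gradients w.r.t. the input, the initial states, and the weights.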
template <typename T>
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *input = ctx.Input<Tensor>("Input");
    auto *init_h = ctx.Input<Tensor>("InitH");
    auto *init_c = ctx.Input<Tensor>("InitC");
    auto *reserve = ctx.Input<Tensor>("Reserve");
    auto *state_out = ctx.Input<Tensor>("StateOut");
    auto weight_list = ctx.MultiInput<Tensor>("WeightList");

    auto *out = ctx.Input<Tensor>("Out");
    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
    auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

    auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
    auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
    auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
        framework::GradVarName("WeightList"));

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    auto input_dims = input->dims();
    auto init_h_dims = init_h->dims();
    auto init_c_dims = init_c->dims();

    auto *init_h_data = init_h->data<T>();
    auto *init_c_data = init_c->data<T>();
    auto *out_data = out->data<T>();
    auto *out_grad_data = out_grad->data<T>();
    auto *last_h_grad_data = last_h_grad->data<T>();
    auto *last_c_grad_data = last_c_grad->data<T>();

    auto place = ctx.GetPlace();
    int weight_numel = size_sum(weight_list);
    bool continuous =
        is_continuous<T, std::vector<const Tensor *>>(weight_list);

    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
                      ctx.device_context())
                      .stream();
    Tensor weight_whole;
    T *weight_data = nullptr;

    if (!continuous) {
      weight_whole.mutable_data<T>({weight_numel}, place);
      weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
      weight_data = weight_whole.data<T>();
    } else {
      weight_data = const_cast<T *>(weight_list[0]->data<T>());
    }

    Tensor weight_grad;
    phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
    weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
    zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
    T *weight_grad_data = weight_grad.data<T>();

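    // Let every WeightList gradient tensor share a slice of the flat
    // weight_grad buffer, so a single backward-weights call fills all of them.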
    int offset = 0;
    for (size_t i = 0; i < weight_grad_list.size(); ++i) {
      size_t len = weight_grad_list[i]->numel();
      auto dim = weight_grad_list[i]->dims();
      weight_grad_list[i]
          ->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
                                            static_cast<int64_t>(offset + len)))
          .Resize(dim);
      offset += len;
    }

    in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
    auto *in_grad_data = in_grad->data<T>();

    if (init_h_grad) init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
    auto *init_h_grad_data = init_h_grad ? init_h_grad->data<T>() : nullptr;

    if (init_c_grad) init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
    auto *init_c_grad_data = init_c_grad ? init_c_grad->data<T>() : nullptr;

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    int seed = ctx.Attr<int>("seed");

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }

    int seq_length = input_dims[0];
    int batch_size = input->dims()[1];
    int input_size = input->dims()[2];

    size_t workspace_size;
    size_t reserve_size;

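    // Rebuild the same RNN descriptors as the forward pass (dropout state is
    // already initialized) to obtain the workspace and reserve-space sizes.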
    ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                      num_layers, dropout_prob, seed, weight_numel, true,
                      is_bidirec);

    rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
                  &reserve_size, const_cast<Tensor *>(state_out));

    framework::Tensor workspace_data_;
    workspace_data_.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
    const uint8_t *reserve_data = reserve->data<uint8_t>();

    if (!has_seq_length) {
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardData(
          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardWeights(
          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
          rnn.weight_desc(), weight_grad_data, workspace_data_.data<uint8_t>(),
          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardData(
          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
          workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
          weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
#endif
    } else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
      // for training
      // This interface is used when the input/output is padded.
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
          handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
          out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
          workspace_data_.data<uint8_t>(), workspace_size,
          const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
          handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
          out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
          rnn.weight_desc(), weight_grad_data,
          const_cast<uint8_t *>(reserve_data), reserve_size));
#else
      PADDLE_THROW(platform::errors::Unavailable(
          "The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
          "cudnnRNNBackwardWeightsEx, but it only works when the cuDNN "
          "version is 7.2.1 or higher."));
#endif
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
#else
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>,
                        ops::CudnnLSTMGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>,
                        ops::CudnnLSTMGPUGradKernel<double>);
#endif