/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <vector>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/miopen_lstm_cache.h"
#endif

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// Returns true when the tensors in `weight_list` are laid out back-to-back in
// memory: each tensor's data (viewed as T) starts exactly where the previous
// tensor's data ends. When this holds, callers in this file pass
// weight_list[0]'s data pointer to cuDNN/MIOpen as one flat weight buffer.
//
// An empty or single-element list is reported as continuous.
//
// `Type` is an indexable container (size(), operator[]) of pointers whose
// pointees provide `template data<T>()` and `numel()`, e.g.
// std::vector<const Tensor *>.
template <typename T, typename Type>
bool is_continuous(const Type &weight_list) {
  bool continuous = true;
  // Bound written as `i + 1 < size()`: the form `i < size() - 1` wraps around
  // to SIZE_MAX for an empty list and then indexes out of range.
  for (size_t i = 0; i + 1 < weight_list.size(); ++i) {
    auto *in_data = weight_list[i]->template data<T>();
    auto *in_after_data = weight_list[i + 1]->template data<T>();
    auto in_size = weight_list[i]->numel();
    // No early exit on a mismatch: every adjacent pair is still visited.
    continuous = continuous && (in_data + in_size == in_after_data);
  }
  return continuous;
}

// Sums numel() over every tensor in `weight_list`, i.e. the element count a
// single flat buffer needs in order to hold all of them back-to-back.
// NOTE(review): the running total is an `int`; each int64_t numel() is
// narrowed on accumulation, so totals above INT_MAX overflow.
int size_sum(const std::vector<const Tensor *> &weight_list) {
  int total = 0;
  for (const Tensor *tensor : weight_list) {
    total += tensor->numel();
  }
  return total;
}

// Packs the tensors of `weight_list` back-to-back, in list order, into the
// flat tensor `weight`. It issues one memory::Copy per tensor on `stream`.
//
// `weight` must already be allocated with room for size_sum(weight_list)
// elements of T. In this file, callers allocate it with
// mutable_data<T>({weight_numel}, place) first.
// `place` is not used: source and destination places are read from the
// tensors themselves.
template <typename T>
void weight_to_tensor(const platform::Place &place,
                      gpuStream_t stream,
                      const std::vector<const Tensor *> &weight_list,
                      Tensor *weight) {
  auto *weight_data = weight->data<T>();
  // 64-bit running offset to match numel(); a 32-bit int would overflow for
  // very large parameter sets.
  int64_t weight_offset = 0;
  for (const Tensor *src : weight_list) {
    const T *in_data = src->data<T>();
    const int64_t in_size = src->numel();

    memory::Copy(weight->place(),
                 weight_data + weight_offset,
                 src->place(),
                 in_data,
                 in_size * sizeof(T),
                 stream);
    weight_offset += in_size;
  }
}

// Inverse of weight_to_tensor: splits the flat tensor `weight` into the
// tensors of `*weight_grad`, allocating each output on `place`.
// Output i receives the next weight_input[i]->numel() elements of `weight`,
// copied on `stream`.
//
// Assumes weight_grad->size() >= weight_input.size() and that each output's
// dims already describe weight_input[i]->numel() elements. TODO confirm at
// call sites.
template <typename T>
void weight_to_tensor_list(const platform::Place &place,
                           gpuStream_t stream,
                           std::vector<Tensor *> *weight_grad,
                           const std::vector<const Tensor *> &weight_input,
                           const Tensor *weight) {
  const T *weight_data = weight->data<T>();
  // 64-bit running offset to match numel() and avoid int overflow.
  int64_t weight_offset = 0;
  for (size_t i = 0; i < weight_input.size(); ++i) {
    const int64_t in_size = weight_input[i]->numel();
    Tensor *dst = (*weight_grad)[i];
    T *weight_grad_data = dst->mutable_data<T>(place);

    memory::Copy(dst->place(),
                 weight_grad_data,
                 weight->place(),
                 weight_data + weight_offset,
                 in_size * sizeof(T),
                 stream);
    weight_offset += in_size;
  }
}

98
template <typename T>
99
#ifdef PADDLE_WITH_HIP
100 101
void LSTMInferece(const bool &has_seq_length,
                  const miopenHandle_t &handle,
102
#else
103 104
void LSTMInferece(const bool &has_seq_length,
                  const cudnnHandle_t &handle,
105
#endif
106 107 108 109 110 111 112 113 114
                  const int &seq_length,
                  ScopedRNNBase *rnn,
                  const T *x_data,
                  const T *init_h_data,
                  const T *init_c_data,
                  const T *w_data,
                  T *out_data,
                  T *last_h_data,
                  T *last_c_data,
115 116 117
                  framework::Tensor *workspace_data,
                  const size_t &workspace_size) {
  if (!has_seq_length) {
118 119 120
// for inference
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
121
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardInference(
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
        handle,
        rnn->rnn_desc(),
        seq_length,
        rnn->x_descs(),
        x_data,
        rnn->init_h_desc(),
        init_h_data,
        rnn->init_c_desc(),
        init_c_data,
        rnn->weight_desc(),
        w_data,
        rnn->y_descs(),
        out_data,
        rnn->last_h_desc(),
        last_h_data,
        rnn->last_c_desc(),
        last_c_data,
        workspace_data->data<uint8_t>(),
        workspace_size));
141
#else
142
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInference(
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
        handle,
        rnn->rnn_desc(),
        seq_length,
        rnn->x_descs(),
        x_data,
        rnn->init_h_desc(),
        init_h_data,
        rnn->init_c_desc(),
        init_c_data,
        rnn->weight_desc(),
        w_data,
        rnn->y_descs(),
        out_data,
        rnn->last_h_desc(),
        last_h_data,
        rnn->last_c_desc(),
        last_c_data,
        workspace_data->data<uint8_t>(),
        workspace_size));
162
#endif
163
  } else {
164
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
165 166
    // for inference
    // This interface is used when the input/output is padded.
167
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
        handle,
        rnn->rnn_desc(),
        rnn->x_seq_desc(),
        x_data,
        rnn->init_h_desc(),
        init_h_data,
        rnn->init_c_desc(),
        init_c_data,
        rnn->weight_desc(),
        w_data,
        rnn->y_seq_desc(),
        out_data,
        rnn->last_h_desc(),
        last_h_data,
        rnn->last_c_desc(),
        last_c_data,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        nullptr,
        workspace_data->data<uint8_t>(),
193 194 195 196 197 198 199 200 201 202 203
        workspace_size));
#else
    // CUDNN VERSION has to >=7.2.1
    PADDLE_THROW(platform::errors::Unavailable(
        "The padded input is supported by "
        "cudnnRNNForwardInferenceEx, but it only works when "
        "the version of cudnn is larger than 7.2.1"));
#endif
  }
}

// Forward kernel of the `cudnn_lstm` operator (cuDNN on CUDA, MIOpen on HIP).
//
// Inputs : Input, InitH, InitC, optional SequenceLength, and the weights.
//          The weights come either from the flat tensor W (used only when
//          is_test is set and W is initialized) or from the per-matrix
//          tensors in WeightList.
// Outputs: Out, LastH, LastC, plus Reserve and StateOut, which the backward
//          kernel reads back as inputs.
// Input's dims are read as [seq_length, batch_size, input_size].
template <typename T>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *x = ctx.Input<Tensor>("Input");
    const Tensor *init_h = ctx.Input<Tensor>("InitH");
    const Tensor *init_c = ctx.Input<Tensor>("InitC");

    Tensor *out = ctx.Output<Tensor>("Out");
    Tensor *last_h = ctx.Output<Tensor>("LastH");
    Tensor *last_c = ctx.Output<Tensor>("LastC");
    Tensor *reserve = ctx.Output<Tensor>("Reserve");
    Tensor *state_out = ctx.Output<Tensor>("StateOut");

    const T *x_data = x->data<T>();
    const T *init_h_data = init_h->data<T>();
    const T *init_c_data = init_c->data<T>();

    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
    T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    bool is_test = ctx.Attr<bool>("is_test");
    int seed = ctx.Attr<int>("seed");

    // In training, seed == 0 means "not specified": draw a seed from the
    // device's default CUDA generator. Otherwise the attribute is used as is.
    if (!is_test && seed == 0) {
      int device_id = ctx.GetPlace().GetDeviceId();
      auto gen_cuda = paddle::framework::DefaultCUDAGenerator(device_id);
      seed = static_cast<int>(gen_cuda->Random64());
    }

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto stream = dev_ctx.stream();
    auto place = ctx.GetPlace();

    int seq_length = x->dims()[0];
    int batch_size = x->dims()[1];
    int input_size = x->dims()[2];
    bool state_initialized = state_out->IsInitialized();

    size_t workspace_size;
    size_t reserve_size;

    // ---- Resolve the flat weight buffer `w_data` --------------------------
    // `weight_whole` owns a packed copy of WeightList when that list is not
    // contiguous, so it must stay alive until the library calls below finish.
    Tensor weight_whole;
    T *w_data = nullptr;
    int weight_numel = 0;
    bool w_initialized = false;
    if (is_test && ctx.HasInput("W")) {
      auto *W = ctx.Input<Tensor>("W");
      w_initialized = W->IsInitialized();
      weight_numel = W->numel();
    }
    if (!w_initialized) {
      auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
      // weight_list[0] is dereferenced below; fail clearly instead of
      // indexing an empty list.
      PADDLE_ENFORCE_EQ(
          weight_list.empty(),
          false,
          platform::errors::InvalidArgument(
              "The input WeightList of cudnn_lstm must not be empty when the "
              "flat weight W is not used."));
      bool continuous =
          is_continuous<T, std::vector<const Tensor *>>(weight_list);
      weight_numel = size_sum(weight_list);

      if (!continuous) {
        LOG_FIRST_N(WARNING, 2)
            << "If the memory space of the Input WeightList is not continuous, "
               "less efficient calculation will be called. Please call "
               "flatten_parameters() to make the input memory continuous.";
        weight_whole.mutable_data<T>({weight_numel}, place);
        weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
        w_data = weight_whole.data<T>();
        if (is_test) {
          // In inference, rebind every WeightList tensor to its slice of
          // `weight_whole`. The list then aliases one contiguous buffer.
          // NOTE: the same rebinding could also be applied during training.
          int64_t offset = 0;
          for (const Tensor *w : weight_list) {
            const int64_t len = w->numel();
            auto dim = w->dims();
            const_cast<Tensor *>(w)
                ->ShareDataWith(weight_whole.Slice(offset, offset + len))
                .Resize(dim);
            offset += len;
          }
        }
      } else {
        w_data = const_cast<T *>(weight_list[0]->data<T>());
      }
    } else {
      auto *W = ctx.Input<Tensor>("W");
      w_data = const_cast<T *>(W->data<T>());
    }

    // ---- Descriptors, workspace and reserve space -------------------------
    ScopedRNNBase rnn(seq_length,
                      batch_size,
                      input_size,
                      hidden_size,
                      num_layers,
                      dropout_prob,
                      seed,
                      weight_numel,
                      state_initialized,
                      is_bidirec);
    rnn.Create<T>(handle,
                  ctx.GetPlace(),
                  SequenceLength,
                  &workspace_size,
                  &reserve_size,
                  state_out);

    framework::Tensor workspace_data;
    workspace_data.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());

    auto *reserve_data = reserve->mutable_data<uint8_t>(
        {static_cast<int64_t>(reserve_size)}, ctx.GetPlace());

    // ---- Forward pass ------------------------------------------------------
    if (is_test) {
      LSTMInferece<T>(has_seq_length,
                      handle,
                      seq_length,
                      &rnn,
                      x_data,
                      init_h_data,
                      init_c_data,
                      w_data,
                      out_data,
                      last_h_data,
                      last_c_data,
                      &workspace_data,
                      workspace_size);
    } else {
      if (!has_seq_length) {
        // Training, unpadded input/output.
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardTraining(
            handle,
            rnn.rnn_desc(),
            seq_length,
            rnn.x_descs(),
            x_data,
            rnn.init_h_desc(),
            init_h_data,
            rnn.init_c_desc(),
            init_c_data,
            rnn.weight_desc(),
            w_data,
            rnn.y_descs(),
            out_data,
            rnn.last_h_desc(),
            last_h_data,
            rnn.last_c_desc(),
            last_c_data,
            workspace_data.data<uint8_t>(),
            workspace_size,
            reserve_data,
            reserve_size));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
            handle,
            rnn.rnn_desc(),
            seq_length,
            rnn.x_descs(),
            x_data,
            rnn.init_h_desc(),
            init_h_data,
            rnn.init_c_desc(),
            init_c_data,
            rnn.weight_desc(),
            w_data,
            rnn.y_descs(),
            out_data,
            rnn.last_h_desc(),
            last_h_data,
            rnn.last_c_desc(),
            last_c_data,
            workspace_data.data<uint8_t>(),
            workspace_size,
            reserve_data,
            reserve_size));
#endif
      } else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
        // Training, padded input/output. The eight nullptr arguments are the
        // optional Ex-API descriptor/data pairs, which this kernel does not
        // use.
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTrainingEx(
            handle,
            rnn.rnn_desc(),
            rnn.x_seq_desc(),
            x_data,
            rnn.init_h_desc(),
            init_h_data,
            rnn.init_c_desc(),
            init_c_data,
            rnn.weight_desc(),
            w_data,
            rnn.y_seq_desc(),
            out_data,
            rnn.last_h_desc(),
            last_h_data,
            rnn.last_c_desc(),
            last_c_data,
            nullptr,
            nullptr,
            nullptr,
            nullptr,
            nullptr,
            nullptr,
            nullptr,
            nullptr,
            workspace_data.data<uint8_t>(),
            workspace_size,
            reserve_data,
            reserve_size));
#else
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardTrainingEx, but it only works when "
            "the version of cudnn is larger than 7.2.1"));
#endif
      }
    }
  }
};

// Backward kernel of the `cudnn_lstm` operator (cuDNN on CUDA, MIOpen on HIP).
//
// It rebuilds the RNN descriptors from the forward attributes and reuses the
// forward outputs Reserve and StateOut. It then calls, in order:
//   1. the backward-data routine, which writes Input@GRAD and, if requested,
//      InitH@GRAD / InitC@GRAD;
//   2. the backward-weights routine, which writes into one flat,
//      zero-initialized gradient buffer. The WeightList@GRAD tensors are
//      views onto slices of that buffer.
template <typename T>
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *input = ctx.Input<Tensor>("Input");
    auto *init_h = ctx.Input<Tensor>("InitH");
    auto *init_c = ctx.Input<Tensor>("InitC");
    auto *reserve = ctx.Input<Tensor>("Reserve");
    auto *state_out = ctx.Input<Tensor>("StateOut");
    auto weight_list = ctx.MultiInput<Tensor>("WeightList");

    auto *out = ctx.Input<Tensor>("Out");
    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
    auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

    auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
    auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
    auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
        framework::GradVarName("WeightList"));

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto stream = dev_ctx.stream();
    auto place = ctx.GetPlace();

    auto input_dims = input->dims();
    auto init_h_dims = init_h->dims();
    auto init_c_dims = init_c->dims();

    auto *input_data = input->data<T>();
    auto *init_h_data = init_h->data<T>();
    auto *init_c_data = init_c->data<T>();
    auto *out_data = out->data<T>();
    auto *out_grad_data = out_grad->data<T>();
    auto *last_h_grad_data = last_h_grad->data<T>();
    auto *last_c_grad_data = last_c_grad->data<T>();

    // ---- Flat view of the forward weights ---------------------------------
    // weight_list[0] is dereferenced below; fail clearly on an empty list.
    PADDLE_ENFORCE_EQ(weight_list.empty(),
                      false,
                      platform::errors::InvalidArgument(
                          "The input WeightList of cudnn_lstm_grad must not "
                          "be empty."));
    int weight_numel = size_sum(weight_list);
    bool continuous =
        is_continuous<T, std::vector<const Tensor *>>(weight_list);

    // Owns a packed copy when WeightList is not contiguous; must outlive the
    // library calls below.
    Tensor weight_whole;
    T *weight_data = nullptr;
    if (!continuous) {
      weight_whole.mutable_data<T>({weight_numel}, place);
      weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
      weight_data = weight_whole.data<T>();
    } else {
      weight_data = const_cast<T *>(weight_list[0]->data<T>());
    }

    // ---- Flat, zero-filled weight gradient --------------------------------
    Tensor weight_grad;
    phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
    weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
    zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
    T *weight_grad_data = weight_grad.data<T>();

    // Each WeightList@GRAD tensor becomes a view onto the slice of
    // `weight_grad` that matches its forward weight. Slice sizes come from
    // WeightList, the layout the flat buffer follows, so a null gradient
    // output is skipped without shifting the slices that follow it.
    PADDLE_ENFORCE_LE(
        weight_grad_list.size(),
        weight_list.size(),
        platform::errors::InvalidArgument(
            "The number of WeightList@GRAD tensors (%d) must not exceed the "
            "number of WeightList tensors (%d).",
            weight_grad_list.size(),
            weight_list.size()));
    int64_t offset = 0;
    for (size_t i = 0; i < weight_grad_list.size(); ++i) {
      const int64_t len = weight_list[i]->numel();
      if (weight_grad_list[i] != nullptr) {
        weight_grad_list[i]
            ->ShareDataWith(weight_grad.Slice(offset, offset + len))
            .Resize(weight_list[i]->dims());
      }
      offset += len;
    }

    // ---- Activation gradients ---------------------------------------------
    T *in_grad_data = in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
    // InitH@GRAD / InitC@GRAD are optional outputs; nullptr is passed to the
    // library when they are absent.
    T *init_h_grad_data =
        init_h_grad ? init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace())
                    : nullptr;
    T *init_c_grad_data =
        init_c_grad ? init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace())
                    : nullptr;

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    int seed = ctx.Attr<int>("seed");

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }

    int seq_length = input_dims[0];
    int batch_size = input_dims[1];
    int input_size = input_dims[2];

    size_t workspace_size;
    size_t reserve_size;

    // The dropout state (StateOut) was produced by the forward kernel, so
    // it is always flagged as initialized here.
    ScopedRNNBase rnn(seq_length,
                      batch_size,
                      input_size,
                      hidden_size,
                      num_layers,
                      dropout_prob,
                      seed,
                      weight_numel,
                      true,
                      is_bidirec);

    rnn.Create<T>(handle,
                  ctx.GetPlace(),
                  SequenceLength,
                  &workspace_size,
                  &reserve_size,
                  const_cast<Tensor *>(state_out));

    framework::Tensor workspace_data;
    workspace_data.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
    const uint8_t *reserve_data = reserve->data<uint8_t>();

    if (!has_seq_length) {
      // Unpadded input/output.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardData(
          handle,
          rnn.rnn_desc(),
          seq_length,
          rnn.y_descs(),
          out_data,
          rnn.y_descs(),
          out_grad_data,
          rnn.last_h_desc(),
          last_h_grad_data,
          rnn.last_c_desc(),
          last_c_grad_data,
          rnn.weight_desc(),
          weight_data,
          rnn.init_h_desc(),
          init_h_data,
          rnn.init_c_desc(),
          init_c_data,
          rnn.x_descs(),
          in_grad_data,
          rnn.init_h_desc(),
          init_h_grad_data,
          rnn.init_c_desc(),
          init_c_grad_data,
          workspace_data.data<uint8_t>(),
          workspace_size,
          const_cast<uint8_t *>(reserve_data),
          reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardWeights(
          handle,
          rnn.rnn_desc(),
          seq_length,
          rnn.x_descs(),
          input_data,
          rnn.init_h_desc(),
          init_h_data,
          rnn.y_descs(),
          out_data,
          rnn.weight_desc(),
          weight_grad_data,
          workspace_data.data<uint8_t>(),
          workspace_size,
          const_cast<uint8_t *>(reserve_data),
          reserve_size));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardData(
          handle,
          rnn.rnn_desc(),
          seq_length,
          rnn.y_descs(),
          out_data,
          rnn.y_descs(),
          out_grad_data,
          rnn.last_h_desc(),
          last_h_grad_data,
          rnn.last_c_desc(),
          last_c_grad_data,
          rnn.weight_desc(),
          weight_data,
          rnn.init_h_desc(),
          init_h_data,
          rnn.init_c_desc(),
          init_c_data,
          rnn.x_descs(),
          in_grad_data,
          rnn.init_h_desc(),
          init_h_grad_data,
          rnn.init_c_desc(),
          init_c_grad_data,
          workspace_data.data<uint8_t>(),
          workspace_size,
          const_cast<uint8_t *>(reserve_data),
          reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
          handle,
          rnn.rnn_desc(),
          seq_length,
          rnn.x_descs(),
          input_data,
          rnn.init_h_desc(),
          init_h_data,
          rnn.y_descs(),
          out_data,
          workspace_data.data<uint8_t>(),
          workspace_size,
          rnn.weight_desc(),
          weight_grad_data,
          const_cast<uint8_t *>(reserve_data),
          reserve_size));
#endif
    } else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
      // Backward pass, padded input/output. The nullptr arguments are the
      // optional Ex-API descriptor/data pairs, which this kernel does not
      // use.
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
          handle,
          rnn.rnn_desc(),
          rnn.y_seq_desc(),
          out_data,
          rnn.y_seq_desc(),
          out_grad_data,
          nullptr,
          nullptr,
          rnn.last_h_desc(),
          last_h_grad_data,
          rnn.last_c_desc(),
          last_c_grad_data,
          rnn.weight_desc(),
          weight_data,
          rnn.init_h_desc(),
          init_h_data,
          rnn.init_c_desc(),
          init_c_data,
          rnn.x_seq_desc(),
          in_grad_data,
          rnn.init_h_desc(),
          init_h_grad_data,
          rnn.init_c_desc(),
          init_c_grad_data,
          nullptr,
          nullptr,
          workspace_data.data<uint8_t>(),
          workspace_size,
          const_cast<uint8_t *>(reserve_data),
          reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
          handle,
          rnn.rnn_desc(),
          rnn.x_seq_desc(),
          input_data,
          rnn.init_h_desc(),
          init_h_data,
          rnn.y_seq_desc(),
          out_data,
          workspace_data.data<uint8_t>(),
          workspace_size,
          rnn.weight_desc(),
          weight_grad_data,
          const_cast<uint8_t *>(reserve_data),
          reserve_size));
#else
      PADDLE_THROW(platform::errors::Unavailable(
          "The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
          "cudnnRNNBackwardWeightsEx, but it only works when the version "
          "of cudnn is larger than 7.2.1"));
#endif
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double, so only float kernels are registered.
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
#else
REGISTER_OP_CUDA_KERNEL(cudnn_lstm,
                        ops::CudnnLSTMGPUKernel<float>,
                        ops::CudnnLSTMGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad,
                        ops::CudnnLSTMGPUGradKernel<float>,
                        ops::CudnnLSTMGPUGradKernel<double>);
#endif