/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/utils.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/cudnn_lstm_cache.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/miopen_lstm_cache.h"
#endif

namespace paddle {
namespace platform {
class CUDADeviceContext;

}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

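// Returns true when the tensors in weight_list occupy one contiguous block
// of memory, i.e. each tensor's data ends exactly where the next begins.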
template <typename T, typename Type>
bool is_continuous(const Type &weight_list) {
  bool continuous = true;
  for (size_t i = 0; i + 1 < weight_list.size(); ++i) {
    auto *in_data = weight_list[i]->template data<T>();
    auto *in_after_data = weight_list[i + 1]->template data<T>();
    auto in_size = weight_list[i]->numel();
    bool temp = in_data + in_size == in_after_data;
    continuous = continuous && temp;
  }
  return continuous;
}

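// Returns the total number of elements across all tensors in weight_list.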
int size_sum(const std::vector<const Tensor *> &weight_list) {
  int size = 0;
  for (size_t i = 0; i < weight_list.size(); ++i) {
    auto in_size = weight_list[i]->numel();
    size += in_size;
  }
  return size;
}

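// Copies every tensor in weight_list, back to back, into the single flat
// tensor `weight` on the given stream.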
template <typename T>
void weight_to_tensor(const platform::Place &place, gpuStream_t stream,
                      const std::vector<const Tensor *> &weight_list,
                      Tensor *weight) {
  auto weight_data = weight->data<T>();
  int weight_offset = 0;
  for (size_t i = 0; i < weight_list.size(); ++i) {
    const T *in_data = weight_list[i]->data<T>();
    auto in_size = weight_list[i]->numel();

    memory::Copy(weight->place(), weight_data + weight_offset,
                 weight_list[i]->place(), in_data, in_size * sizeof(T), stream);
    weight_offset += in_size;
  }
}

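// Copies slices of the flat tensor `weight` back into the per-layer tensors
// of `weight_grad`, using the element counts of `weight_input` as offsets.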
template <typename T>
void weight_to_tensor_list(const platform::Place &place, gpuStream_t stream,
                           std::vector<Tensor *> *weight_grad,
                           const std::vector<const Tensor *> &weight_input,
                           const Tensor *weight) {
  int weight_offset = 0;
  auto *weight_data = weight->data<T>();
  for (size_t i = 0; i < weight_input.size(); ++i) {
    auto in_size = weight_input[i]->numel();
    T *weight_grad_data = (*weight_grad)[i]->mutable_data<T>(place);
    const T *src = weight_data + weight_offset;

    memory::Copy((*weight_grad)[i]->place(), weight_grad_data, weight->place(),
                 src, in_size * sizeof(T), stream);
    weight_offset += in_size;
  }
}

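// Runs the inference-only RNN forward pass; the plain (unpadded) interface
// is used by default, and the padded, sequence-length-aware interface when
// has_seq_length is true.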
template <typename T>
#ifdef PADDLE_WITH_HIP
void LSTMInference(const bool &has_seq_length, const miopenHandle_t &handle,
#else
void LSTMInference(const bool &has_seq_length, const cudnnHandle_t &handle,
#endif
                   const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
                   const T *init_h_data, const T *init_c_data, const T *w_data,
                   T *out_data, T *last_h_data, T *last_c_data,
                   framework::Tensor *workspace_data,
                   const size_t &workspace_size) {
  if (!has_seq_length) {
// for inference
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardInference(
        handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
        rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
        rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
        rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
        workspace_data->data<uint8_t>(), workspace_size));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInference(
        handle, rnn->rnn_desc(), seq_length, rnn->x_descs(), x_data,
        rnn->init_h_desc(), init_h_data, rnn->init_c_desc(), init_c_data,
        rnn->weight_desc(), w_data, rnn->y_descs(), out_data,
        rnn->last_h_desc(), last_h_data, rnn->last_c_desc(), last_c_data,
        workspace_data->data<uint8_t>(), workspace_size));
#endif
  } else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
    // for inference
    // This interface is used when the input/output is padded.
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardInferenceEx(
        handle, rnn->rnn_desc(), rnn->x_seq_desc(), x_data, rnn->init_h_desc(),
        init_h_data, rnn->init_c_desc(), init_c_data, rnn->weight_desc(),
        w_data, rnn->y_seq_desc(), out_data, rnn->last_h_desc(), last_h_data,
        rnn->last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
        nullptr, nullptr, nullptr, nullptr, workspace_data->data<uint8_t>(),
        workspace_size));
#else
    // cuDNN version must be >= 7.2.1
    PADDLE_THROW(platform::errors::Unavailable(
        "The padded input is supported by "
        "cudnnRNNForwardInferenceEx, but it only works when "
        "the version of cudnn is no less than 7.2.1"));
#endif
  }
}

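// Forward kernel for the cudnn_lstm op: gathers (and if necessary flattens)
// the weights, seeds the dropout state, builds the RNN descriptors, and
// dispatches to the inference or training forward call.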
template <typename T>
class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *x = ctx.Input<Tensor>("Input");
    const Tensor *init_h = ctx.Input<Tensor>("InitH");
    const Tensor *init_c = ctx.Input<Tensor>("InitC");

    Tensor *out = ctx.Output<Tensor>("Out");
    Tensor *last_h = ctx.Output<Tensor>("LastH");
    Tensor *last_c = ctx.Output<Tensor>("LastC");
    Tensor *reserve = ctx.Output<Tensor>("Reserve");
    Tensor *state_out = ctx.Output<Tensor>("StateOut");

    const T *x_data = x->data<T>();
    const T *init_h_data = init_h->data<T>();
    const T *init_c_data = init_c->data<T>();

    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
    T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    bool is_test = ctx.Attr<bool>("is_test");
    int seed = ctx.Attr<int>("seed");

    if (!is_test) {
      int device_id = ctx.GetPlace().GetDeviceId();
      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
      if (gen_cuda->GetIsInitPy() && seed == 0) {
        // If `manual_seed` was called in Python and no seed was specified on
        // the op (seed == 0), use a seed from the global generator.
        seed = static_cast<int>(gen_cuda->Random64());
      } else if (seed == 0) {
        // use a randomly generated seed
        std::random_device rd;
        seed = rd();
      }  // else use the seed specified via `ctx.Attr<int>("seed")`
    }

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    int seq_length = x->dims()[0];
    int batch_size = x->dims()[1];
    int input_size = x->dims()[2];
    bool state_initialized = state_out->IsInitialized();

    size_t workspace_size;
    size_t reserve_size;
    Tensor weight_whole;
    T *w_data = nullptr;
    int weight_numel;
    bool w_initialized = false;
    auto place = ctx.GetPlace();
    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
                      ctx.device_context())
                      .stream();
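    // In inference a pre-flattened weight tensor "W" may be provided;
    // otherwise the per-layer "WeightList" inputs are used (and flattened
    // here if they are not already contiguous).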
    if (is_test && ctx.HasInput("W")) {
      auto *W = ctx.Input<Tensor>("W");
      w_initialized = W->IsInitialized();
      weight_numel = W->numel();
    }
    if (!w_initialized) {
      auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
      bool continuous =
          is_continuous<T, std::vector<const Tensor *>>(weight_list);
      weight_numel = size_sum(weight_list);

      if (!continuous) {
        LOG_FIRST_N(WARNING, 2)
            << "If the memory space of the Input WeightList is not continuous, "
               "less efficient calculation will be called. Please call "
               "flatten_parameters() to make the input memory continuous.";
        weight_whole.mutable_data<T>({weight_numel}, place);
        weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
        w_data = weight_whole.data<T>();
        if (is_test) {  // maybe also reset small weights' ptr for training
          int offset = 0;
          for (size_t i = 0; i < weight_list.size(); ++i) {
            size_t len = weight_list[i]->numel();
            auto dim = weight_list[i]->dims();
            const_cast<Tensor *>(weight_list[i])
                ->ShareDataWith(
                    weight_whole.Slice(static_cast<int64_t>(offset),
                                       static_cast<int64_t>(offset + len)))
                .Resize(dim);
            offset += len;
          }
        }
      } else {
        w_data = const_cast<T *>(weight_list[0]->data<T>());
      }
    } else {
      auto *W = ctx.Input<Tensor>("W");
      w_data = const_cast<T *>(W->data<T>());
    }

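    // Build the RNN descriptors for this sequence shape and query the
    // workspace and reserve sizes.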
    ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                      num_layers, dropout_prob, seed, weight_numel,
                      state_initialized, is_bidirec);
    rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
                  &reserve_size, state_out);

    framework::Tensor workspace_data_;
    workspace_data_.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());

    auto *reserve_data = reserve->mutable_data<uint8_t>(
        {static_cast<int64_t>(reserve_size)}, ctx.GetPlace());

    if (is_test) {
      LSTMInference<T>(has_seq_length, handle, seq_length, &rnn, x_data,
                       init_h_data, init_c_data, w_data, out_data, last_h_data,
                       last_c_data, &workspace_data_, workspace_size);
    } else {
      if (!has_seq_length) {
// for train
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNForwardTraining(
            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
            rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTraining(
            handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), x_data,
            rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
            rnn.weight_desc(), w_data, rnn.y_descs(), out_data,
            rnn.last_h_desc(), last_h_data, rnn.last_c_desc(), last_c_data,
            workspace_data_.data<uint8_t>(), workspace_size, reserve_data,
            reserve_size));
#endif
      } else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
        // for train
        // This interface is used when the input/output is padded.
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNForwardTrainingEx(
            handle, rnn.rnn_desc(), rnn.x_seq_desc(), x_data, rnn.init_h_desc(),
            init_h_data, rnn.init_c_desc(), init_c_data, rnn.weight_desc(),
            w_data, rnn.y_seq_desc(), out_data, rnn.last_h_desc(), last_h_data,
            rnn.last_c_desc(), last_c_data, nullptr, nullptr, nullptr, nullptr,
            nullptr, nullptr, nullptr, nullptr, workspace_data_.data<uint8_t>(),
            workspace_size, reserve_data, reserve_size));
#else
        PADDLE_THROW(platform::errors::Unavailable(
            "The padded input is supported by "
            "cudnnRNNForwardTrainingEx, but it only works when "
            "the version of cudnn is no less than 7.2.1"));
#endif
      }
    }
  }
};

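// Backward kernel for the cudnn_lstm op: recreates the RNN descriptors of
// the forward pass and computes the input, initial-state, and weight
// gradients via the BackwardData/BackwardWeights calls.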
template <typename T>
class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *input = ctx.Input<Tensor>("Input");
    auto *init_h = ctx.Input<Tensor>("InitH");
    auto *init_c = ctx.Input<Tensor>("InitC");
    auto *reserve = ctx.Input<Tensor>("Reserve");
    auto *state_out = ctx.Input<Tensor>("StateOut");
    auto weight_list = ctx.MultiInput<Tensor>("WeightList");

    auto *out = ctx.Input<Tensor>("Out");
    auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
    auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

    auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
    auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
    auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
        framework::GradVarName("WeightList"));

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();

    auto input_dims = input->dims();
    auto init_h_dims = init_h->dims();
    auto init_c_dims = init_c->dims();

    auto *init_h_data = init_h->data<T>();
    auto *init_c_data = init_c->data<T>();
    auto *out_data = out->data<T>();
    auto *out_grad_data = out_grad->data<T>();
    auto *last_h_grad_data = last_h_grad->data<T>();
    auto *last_c_grad_data = last_c_grad->data<T>();

    auto place = ctx.GetPlace();
    int weight_numel = size_sum(weight_list);
    bool continuous =
        is_continuous<T, std::vector<const Tensor *>>(weight_list);

    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
                      ctx.device_context())
                      .stream();
    Tensor weight_whole;
    T *weight_data = nullptr;

    if (!continuous) {
      weight_whole.mutable_data<T>({weight_numel}, place);
      weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
      weight_data = weight_whole.data<T>();
    } else {
      weight_data = const_cast<T *>(weight_list[0]->data<T>());
    }

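    // Accumulate all weight gradients in one flat buffer and alias each
    // per-layer gradient tensor to its slice of that buffer.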
    Tensor weight_grad;
    math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
    weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
    zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
    T *weight_grad_data = weight_grad.data<T>();

    int offset = 0;
    for (size_t i = 0; i < weight_grad_list.size(); ++i) {
      size_t len = weight_grad_list[i]->numel();
      auto dim = weight_grad_list[i]->dims();
      weight_grad_list[i]
          ->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
                                            static_cast<int64_t>(offset + len)))
          .Resize(dim);
      offset += len;
    }

    in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
    auto *in_grad_data = in_grad->data<T>();

    if (init_h_grad) init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
    auto *init_h_grad_data = init_h_grad ? init_h_grad->data<T>() : nullptr;

    if (init_c_grad) init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
    auto *init_c_grad_data = init_c_grad ? init_c_grad->data<T>() : nullptr;

    float dropout_prob = ctx.Attr<float>("dropout_prob");
    bool is_bidirec = ctx.Attr<bool>("is_bidirec");
    int hidden_size = ctx.Attr<int>("hidden_size");
    int num_layers = ctx.Attr<int>("num_layers");
    int seed = ctx.Attr<int>("seed");

    bool has_seq_length = ctx.HasInput("SequenceLength");
    std::vector<int> SequenceLength;
    if (has_seq_length) {
      auto *sequence_length = ctx.Input<Tensor>("SequenceLength");
      SequenceLength = operators::GetDataFromTensor<int>(sequence_length);
    }

    int seq_length = input_dims[0];
    int batch_size = input->dims()[1];
    int input_size = input->dims()[2];

    size_t workspace_size;
    size_t reserve_size;

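    // Recreate the same RNN descriptors as the forward pass; the dropout
    // state in state_out is already initialized.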
    ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
                      num_layers, dropout_prob, seed, weight_numel, true,
                      is_bidirec);

    rnn.Create<T>(handle, ctx.GetPlace(), SequenceLength, &workspace_size,
                  &reserve_size, const_cast<Tensor *>(state_out));

    framework::Tensor workspace_data_;
    workspace_data_.mutable_data<uint8_t>(
        {static_cast<int64_t>(workspace_size)}, ctx.GetPlace());
    const uint8_t *reserve_data = reserve->data<uint8_t>();

    if (!has_seq_length) {
// This interface is used when the input/output is unpadded.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardData(
          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenRNNBackwardWeights(
          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
          rnn.weight_desc(), weight_grad_data, workspace_data_.data<uint8_t>(),
          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardData(
          handle, rnn.rnn_desc(), seq_length, rnn.y_descs(), out_data,
          rnn.y_descs(), out_grad_data, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_descs(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, workspace_data_.data<uint8_t>(),
          workspace_size, const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeights(
          handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
          workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
          weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
#endif
    } else {
#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201
      // for train
      // This interface is used when the input/output is padded.
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardDataEx(
          handle, rnn.rnn_desc(), rnn.y_seq_desc(), out_data, rnn.y_seq_desc(),
          out_grad_data, nullptr, nullptr, rnn.last_h_desc(), last_h_grad_data,
          rnn.last_c_desc(), last_c_grad_data, rnn.weight_desc(), weight_data,
          rnn.init_h_desc(), init_h_data, rnn.init_c_desc(), init_c_data,
          rnn.x_seq_desc(), in_grad_data, rnn.init_h_desc(), init_h_grad_data,
          rnn.init_c_desc(), init_c_grad_data, nullptr, nullptr,
          workspace_data_.data<uint8_t>(), workspace_size,
          const_cast<uint8_t *>(reserve_data), reserve_size));

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnRNNBackwardWeightsEx(
          handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
          rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
          out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
          rnn.weight_desc(), weight_grad_data,
          const_cast<uint8_t *>(reserve_data), reserve_size));
#else
      PADDLE_THROW(platform::errors::Unavailable(
          "The padded input of rnn is supported by cudnnRNNBackwardDataEx "
          "and cudnnRNNBackwardWeightsEx, but they only work when the "
          "version of cudnn is no less than 7.2.1"));
#endif
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
#else
REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>,
                        ops::CudnnLSTMGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>,
                        ops::CudnnLSTMGPUGradKernel<double>);
#endif