// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rnn_grad_kernel.h"

#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/gpu/rnn_functor.h"

namespace phi {

#ifdef PADDLE_WITH_HIP
/**
 * Scatters the flat MIOpen weight-gradient buffer `tensor` back into the
 * per-parameter gradient tensors of `weight_grad_list` (ROCm build only).
 *
 * Consecutive slices of `tensor` are assigned to the destination tensors in
 * list order, except that for bidirectional RNNs entries 1 and 2 of every
 * group of four list entries are visited in swapped order. Within each slice,
 * the gate chunks are reordered before copying:
 *   - miopenLSTM: slice chunks {0, 1, 3, 2} are written to the destination;
 *   - miopenGRU:  slice chunks {1, 0, 2} are written to the destination;
 *   - otherwise:  the slice is copied as-is.
 * NOTE(review): this presumably mirrors the packing done by
 * WeightToPermutedTensor (defined elsewhere) — confirm if that changes.
 *
 * @param place            device place used for the copies
 * @param stream           stream the copies are issued on
 * @param tensor           flat gradient buffer in MIOpen layout
 * @param weight_grad_list destination tensors (must already be allocated);
 *                         the pointer vector itself is left untouched
 * @param rnn_mode         backend RNN mode selecting the gate permutation
 * @param is_bidirec       whether the RNN is bidirectional
 */
template <typename T>
void TensorToPermutedWeight(const Place &place,
                            gpuStream_t stream,
                            const DenseTensor &tensor,
                            std::vector<DenseTensor *> *weight_grad_list,
                            const gpuRNNMode_t rnn_mode,
                            bool is_bidirec) {
  // Reorder a local copy of the pointer list instead of swapping the caller's
  // vector and swapping it back, so the caller never observes (or is left
  // with, if a copy below throws) a permuted list.
  std::vector<DenseTensor *> ordered(*weight_grad_list);
  if (is_bidirec) {
    // `i + 2 < size` keeps both swapped indices in range.
    for (size_t i = 0; i + 2 < ordered.size(); i += 4) {
      DenseTensor *tmp = ordered[i + 1];
      ordered[i + 1] = ordered[i + 2];
      ordered[i + 2] = tmp;
    }
  }

  int64_t weight_offset = 0;
  for (DenseTensor *dst : ordered) {
    const int64_t numel_size = dst->numel();
    // View (no copy) of this parameter's slice of the flat buffer.
    DenseTensor temp;
    temp.Resize({numel_size});
    temp.ShareDataWith(tensor.Slice(weight_offset, weight_offset + numel_size));

    if (rnn_mode == miopenLSTM) {
      std::vector<DenseTensor> split_tensor = temp.Chunk(4, 0);
      WeightListToTensor<T>(
          place,
          stream,
          {split_tensor[0], split_tensor[1], split_tensor[3], split_tensor[2]},
          dst);
    } else if (rnn_mode == miopenGRU) {
      std::vector<DenseTensor> split_tensor = temp.Chunk(3, 0);
      WeightListToTensor<T>(place,
                            stream,
                            {split_tensor[1], split_tensor[0], split_tensor[2]},
                            dst);
    } else {
      WeightListToTensor<T>(place, stream, {temp}, dst);
    }
    weight_offset += numel_size;
  }
}
#endif

/**
 * GPU backward pass of the RNN operator, built on cuDNN (CUDA) or MIOpen
 * (ROCm).
 *
 * Computes the gradient with respect to the input sequence (`x_grad`), the
 * initial states (`pre_state_grad`) and every weight tensor
 * (`weight_grad_list`). It uses the upstream gradients `out_grad` /
 * `state_grad` and the `reserve` buffer written by the forward pass.
 *
 * Layout notes:
 *  - `x` is read as [seq_length, batch_size, input_size]. The input size is
 *    taken from `x.dims()[2]`; the `input_size` attribute is not used here.
 *  - `pre_state[0]` / `state_grad[0]` hold h. For LSTM, `pre_state[1]` /
 *    `state_grad[1]` additionally hold c.
 *  - The library receives all weights as one flat buffer. If `weight_list`
 *    is not already contiguous, it is packed into a temporary. On ROCm this
 *    packing always happens, because the weights must also be permuted.
 *
 * Errors: InvalidArgument for an unknown `mode`, or (ROCm) when
 * `sequence_length` is given. Unavailable when padded input (a sequence
 * length) is used with cuDNN older than 7.2.1.
 */
template <typename T, typename Context>
void RnnGradKernel(const Context &dev_ctx,
                   const DenseTensor &x,
                   const std::vector<const DenseTensor *> &pre_state,
                   const std::vector<const DenseTensor *> &weight_list,
                   const paddle::optional<DenseTensor> &sequence_length,
                   const DenseTensor &out,
                   const DenseTensor &dropout_state,
                   const DenseTensor &reserve,
                   const DenseTensor &out_grad,
                   const std::vector<const DenseTensor *> &state_grad,
                   float dropout_prob,
                   bool is_bidirec,
                   int input_size,
                   int hidden_size,
                   int num_layers,
                   const std::string &mode,
                   int seed,
                   bool is_test,
                   DenseTensor *x_grad,
                   std::vector<DenseTensor *> pre_state_grad,
                   std::vector<DenseTensor *> weight_grad_list) {
  // Map the mode string onto the backend's RNN mode enum.
#ifdef PADDLE_WITH_HIP
  miopenRNNMode_t rnn_mode = miopenLSTM;
  if (mode == "LSTM")
    rnn_mode = miopenLSTM;
  else if (mode == "GRU")
    rnn_mode = miopenGRU;
  else if (mode == "RNN_RELU")
    rnn_mode = miopenRNNRELU;
  else if (mode == "RNN_TANH")
    rnn_mode = miopenRNNTANH;
#else
  cudnnRNNMode_t rnn_mode = CUDNN_LSTM;
  if (mode == "LSTM")
    rnn_mode = CUDNN_LSTM;
  else if (mode == "GRU")
    rnn_mode = CUDNN_GRU;
  else if (mode == "RNN_RELU")
    rnn_mode = CUDNN_RNN_RELU;
  else if (mode == "RNN_TANH")
    rnn_mode = CUDNN_RNN_TANH;
#endif
  // This `else` closes whichever if / else-if chain the preprocessor kept.
  else
    PADDLE_THROW(phi::errors::InvalidArgument(
        "rnn_mode should be LSTM, GRU, RNN_RELU or RNN_TANH, but received: "
        "%s.",
        mode));
  auto handle = dev_ctx.cudnn_handle();
  auto place = dev_ctx.GetPlace();
  // Total number of weight elements. The init value must be int64_t:
  // std::accumulate accumulates in the type of its init argument, so a plain
  // `0` would truncate the running sum of int64_t numel() values to int.
  const int64_t weight_numel = std::accumulate(
      weight_list.begin(),
      weight_list.end(),
      static_cast<int64_t>(0),
      [](int64_t num, const DenseTensor *t) { return num + t->numel(); });
  bool continuous =
      IsContinuous<T, std::vector<const DenseTensor *>>(weight_list);
  auto stream = dev_ctx.stream();
  DenseTensor weight_whole;
  T *weight_data = nullptr;

#ifdef PADDLE_WITH_HIP
  // Need to permute weight, set continuous to false
  continuous = false;
#endif

  if (!continuous) {
    // Pack the individual weight tensors into one flat buffer.
    weight_whole.Resize({weight_numel});
    dev_ctx.template Alloc<T>(&weight_whole);
#ifdef PADDLE_WITH_HIP
    // MIOpen needs the weights permuted for miopenLSTM / miopenGRU.
    std::vector<const DenseTensor *> weight_list_tmp = weight_list;
    WeightToPermutedTensor<T>(
        place, stream, &weight_list_tmp, &weight_whole, rnn_mode, is_bidirec);
#else
    WeightToTensor<T>(place, stream, weight_list, &weight_whole);
#endif
    weight_data = weight_whole.data<T>();
  } else {
    // Weights already sit back-to-back in memory; use them in place.
    weight_data = const_cast<T *>(weight_list[0]->data<T>());
  }

  // Zero-filled because the backend adds weight gradients into this buffer
  // rather than overwriting it.
  DenseTensor weight_grad = Full<T>(dev_ctx, {weight_numel}, 0);
  T *weight_grad_data = weight_grad.data<T>();

#ifdef PADDLE_WITH_HIP
  // MIOpen needs weight_grad_list permuted afterwards, so the outputs get
  // their own storage instead of sharing data with weight_grad.
  for (size_t i = 0; i < weight_grad_list.size(); ++i) {
    dev_ctx.template Alloc<T>(weight_grad_list[i]);
  }
#else
  // Make each weight gradient a view into its slice of the flat buffer.
  int64_t offset = 0;
  for (size_t i = 0; i < weight_grad_list.size(); ++i) {
    const int64_t len = weight_grad_list[i]->numel();
    auto dim = weight_grad_list[i]->dims();
    weight_grad_list[i]
        ->ShareDataWith(weight_grad.Slice(offset, offset + len))
        .Resize(dim);
    offset += len;
  }
#endif

  // When the caller does not need x_grad, a local scratch tensor is still
  // used. The data-gradient call below must run before the weight-gradient
  // call, which consumes the workspace it produces.
  DenseTensor input_grad_value;
  if (!x_grad) {
    x_grad = &input_grad_value;
    x_grad->Resize(x.dims());
  }

  auto *init_h_data = pre_state[0]->data<T>();
  auto *last_h_grad_data = state_grad[0]->data<T>();
  const T *init_c_data = nullptr;
  const T *last_c_grad_data = nullptr;
  T *init_h_grad_data = pre_state_grad.size() != 0 && pre_state_grad[0]
                            ? dev_ctx.template Alloc<T>(pre_state_grad[0])
                            : nullptr;
  T *init_c_grad_data = nullptr;
  // Only LSTM carries a cell state (c).
#ifdef PADDLE_WITH_HIP
  if (rnn_mode == miopenLSTM) {
#else
  if (rnn_mode == CUDNN_LSTM) {
#endif
    init_c_data = pre_state[1]->data<T>();
    last_c_grad_data = state_grad[1]->data<T>();
    init_c_grad_data = pre_state_grad.size() >= 2 && pre_state_grad[1]
                           ? dev_ctx.template Alloc<T>(pre_state_grad[1])
                           : nullptr;
  }
  auto *out_data = out.data<T>();
  auto *out_grad_data = out_grad.data<T>();

  // x_grad is always non-null at this point (see the scratch tensor above).
  T *x_grad_data = nullptr;
  if (x_grad) {
    x_grad_data = dev_ctx.template Alloc<T>(x_grad);
  }

  bool has_seq_length = sequence_length.is_initialized();
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_EQ(
      has_seq_length,
      false,
      phi::errors::InvalidArgument("ROCm do not support SequenceLength yet."));
#endif
  std::vector<int> SequenceLength;
  if (has_seq_length) {
    SequenceLength =
        paddle::operators::GetDataFromTensor<int>(sequence_length.get_ptr());
  }

  auto input_dims = x.dims();
  int seq_length = input_dims[0];
  int batch_size = input_dims[1];
  int input_size_local = input_dims[2];

  size_t workspace_size;
  size_t reserve_size;

  RNNDescriptors rnn(seq_length,
                     batch_size,
                     input_size_local,
                     hidden_size,
                     num_layers,
                     dropout_prob,
                     seed,
                     weight_numel,
                     rnn_mode,
                     is_bidirec,
                     is_test);

  rnn.Create<T>(handle,
                dev_ctx,
                SequenceLength,
                &workspace_size,
                &reserve_size,
                const_cast<DenseTensor *>(&dropout_state));

  DenseTensor workspace_data_ =
      Empty<uint8_t>(dev_ctx, {static_cast<int64_t>(workspace_size)});
  const uint8_t *reserve_data = reserve.data<uint8_t>();

  if (!has_seq_length) {
    // Unpadded input: legacy (non-Ex) backward APIs.
    if (x_grad) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::miopenRNNBackwardData(
              handle,
              rnn.rnn_desc(),
              seq_length,
              rnn.y_descs(),
              out_data,
              rnn.y_descs(),
              out_grad_data,
              rnn.last_h_desc(),
              last_h_grad_data,
              rnn.last_c_desc(),
              last_c_grad_data,
              rnn.weight_desc(),
              weight_data,
              rnn.init_h_desc(),
              init_h_data,
              rnn.init_c_desc(),
              init_c_data,
              rnn.x_descs(),
              x_grad_data,
              rnn.init_h_desc(),
              init_h_grad_data,
              rnn.init_c_desc(),
              init_c_grad_data,
              workspace_data_.data<uint8_t>(),
              workspace_size,
              const_cast<uint8_t *>(reserve_data),
              reserve_size));
#else
      // This interface is used when the input/output is unpadded.
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::cudnnRNNBackwardData(
              handle,
              rnn.rnn_desc(),
              seq_length,
              rnn.y_descs(),
              out_data,
              rnn.y_descs(),
              out_grad_data,
              rnn.last_h_desc(),
              last_h_grad_data,
              rnn.last_c_desc(),
              last_c_grad_data,
              rnn.weight_desc(),
              weight_data,
              rnn.init_h_desc(),
              init_h_data,
              rnn.init_c_desc(),
              init_c_data,
              rnn.x_descs(),
              x_grad_data,
              rnn.init_h_desc(),
              init_h_grad_data,
              rnn.init_c_desc(),
              init_c_grad_data,
              workspace_data_.data<uint8_t>(),
              workspace_size,
              const_cast<uint8_t *>(reserve_data),
              reserve_size));
#endif
    }
    if (!weight_grad_list.empty()) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::miopenRNNBackwardWeights(
              handle,
              rnn.rnn_desc(),
              seq_length,
              rnn.x_descs(),
              x.data<T>(),
              rnn.init_h_desc(),
              init_h_data,
              rnn.y_descs(),
              out.data<T>(),
              rnn.weight_desc(),
              weight_grad_data,
              workspace_data_.data<uint8_t>(),
              workspace_size,
              const_cast<uint8_t *>(reserve_data),
              reserve_size));
      // permute weight grad list from weight grad tensor
      TensorToPermutedWeight<T>(
          place, stream, weight_grad, &weight_grad_list, rnn_mode, is_bidirec);
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::cudnnRNNBackwardWeights(
              handle,
              rnn.rnn_desc(),
              seq_length,
              rnn.x_descs(),
              x.data<T>(),
              rnn.init_h_desc(),
              init_h_data,
              rnn.y_descs(),
              out.data<T>(),
              workspace_data_.data<uint8_t>(),
              workspace_size,
              rnn.weight_desc(),
              weight_grad_data,
              const_cast<uint8_t *>(reserve_data),
              reserve_size));
#endif
    }
  } else {
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201
    // for train
    // This interface is used when the input/output is padded.
    if (x_grad) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::cudnnRNNBackwardDataEx(
              handle,
              rnn.rnn_desc(),
              rnn.y_seq_desc(),
              out_data,
              rnn.y_seq_desc(),
              out_grad_data,
              nullptr,
              nullptr,
              rnn.last_h_desc(),
              last_h_grad_data,
              rnn.last_c_desc(),
              last_c_grad_data,
              rnn.weight_desc(),
              weight_data,
              rnn.init_h_desc(),
              init_h_data,
              rnn.init_c_desc(),
              init_c_data,
              rnn.x_seq_desc(),
              x_grad_data,
              rnn.init_h_desc(),
              init_h_grad_data,
              rnn.init_c_desc(),
              init_c_grad_data,
              nullptr,
              nullptr,
              workspace_data_.data<uint8_t>(),
              workspace_size,
              const_cast<uint8_t *>(reserve_data),
              reserve_size));
    }

    if (!weight_grad_list.empty()) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::cudnnRNNBackwardWeightsEx(
              handle,
              rnn.rnn_desc(),
              rnn.x_seq_desc(),
              x.data<T>(),
              rnn.init_h_desc(),
              init_h_data,
              rnn.y_seq_desc(),
              out.data<T>(),
              workspace_data_.data<uint8_t>(),
              workspace_size,
              rnn.weight_desc(),
              weight_grad_data,
              const_cast<uint8_t *>(reserve_data),
              reserve_size));
    }
#else
    PADDLE_THROW(phi::errors::Unavailable(
        "The padded input of rnn is supported by cudnnRNNBackwardDataEx, "
        "cudnnRNNBackwardWeightsEx, but it only works when the version "
        "of cudnn is larger than 7.2.1"));
#endif
  }
}

}  // namespace phi

#ifndef PADDLE_WITH_HIP
// CUDA / cuDNN build: register both float and double kernels.
PD_REGISTER_KERNEL(
    rnn_grad, GPU, ALL_LAYOUT, phi::RnnGradKernel, float, double) {}
#else
// ROCm / MIOpen build: MIOpen lacks double support, so only float is
// registered.
PD_REGISTER_KERNEL(rnn_grad, GPU, ALL_LAYOUT, phi::RnnGradKernel, float) {}
#endif