// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <mutex>  // NOLINT

#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace platform {

/*
 * Summary: Grid stride looping macro in CUDA kernel
 *
 *  [ Why need this macro? ]
 *
 *    The original looping in CUDA kernel is:
 *
 *    `for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
 *        i += blockDim.x * gridDim.x)`
 *
 *    This loop condition is risky. The value of `blockIdx.x * blockDim.x`
 *    may be large (e.g. above 10^9). The first iteration is fine, but when
 *    `i += blockDim.x * gridDim.x` is executed, `i` can exceed INT_MAX and
 *    overflow to a negative value. The loop condition `i < (n)` is then
 *    still satisfied, so the kernel performs an illegal CUDA memory access.
 *
 *    Here is a real example from ERNIE that triggers the error above.
 *    The related data are:
 *      - blockIdx.x = 2172938
 *      - blockDim.x = 512
 *      - blockIdx.x * blockDim.x = 1112543864
 *      - INT_MAX = 2147483647
 *
 *    The macro therefore keeps the loop counter in an int64_t (`__index__`),
 *    which prevents overflow in the loop increment.
 *
 * Parameters:
 *    - i: name of the loop index variable
 *    - num: total number of elements
 *    - index_type: type of `i` (assigned from the int64_t counter each step)
 *
 * Examples:
 *    template <typename T>
 *    __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
 *                      const int d, const int remain) {
 *    CUDA_KERNEL_LOOP(index, num) {
 *      int idx_n = index / d;
 *      int idx_remain = index % remain;
 *      logit_grad[index] *= loss_grad[idx_n * remain + idx_remain];
 *      }
 *    }
 *
*/

// Grid-stride loop over [0, num) with a 64-bit loop counter.
//
//   i          - name of the loop variable visible inside the loop body
//   num        - total number of elements to visit
//   index_type - type of `i` (e.g. int or int64_t); `i` is re-assigned from
//                the int64_t counter on every step, so a narrower type
//                truncates the value
//
// Both the starting offset and the stride are computed in 64-bit arithmetic.
// blockIdx.x, blockDim.x and gridDim.x are unsigned 32-bit values, so their
// raw product would wrap modulo 2^32 before being widened to int64_t, which
// silently breaks the loop for large grids.
//
// NOTE: the macro declares `__index__` in the enclosing scope, so it can be
// expanded at most once per scope.
#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                  \
  int64_t __index__ =                                              \
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
  for (index_type i = __index__; __index__ < (num);                \
       __index__ += static_cast<int64_t>(blockDim.x) * gridDim.x,  \
                  i = __index__)
// Owns a cuBLAS handle that is bound to one CUDA stream.
//
// The handle is created in the constructor and destroyed in the destructor.
// Copying is disabled (DISABLE_COPY_AND_ASSIGN), so each holder is the sole
// owner of its handle. Call() serializes use of the handle with a mutex;
// GetCublasHandle() returns the raw handle without taking that lock.
class CublasHandleHolder {
 public:
  // Creates the handle, binds it to `stream`, and applies `math_type` when the
  // toolkit this file is compiled against supports it:
  //   - CUBLAS_TENSOR_OP_MATH      : CUDA_VERSION >= 9000
  //   - CUBLAS_TF32_TENSOR_OP_MATH : CUDA_VERSION >= 11000
  // Any other `math_type` leaves the handle's math mode unchanged.
  CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasCreate(&handle_));
    // A constructor that exits via an exception never runs the destructor, so
    // release the handle here if a later setup step throws (presumably
    // PADDLE_RETRY_CUDA_SUCCESS throws on a non-success status -- TODO
    // confirm), then rethrow the original error.
    try {
      PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasSetStream(handle_, stream));
#if CUDA_VERSION >= 9000
      if (math_type == CUBLAS_TENSOR_OP_MATH) {
        PADDLE_RETRY_CUDA_SUCCESS(
            dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
#if CUDA_VERSION >= 11000
      } else if (math_type == CUBLAS_TF32_TENSOR_OP_MATH) {
        PADDLE_RETRY_CUDA_SUCCESS(
            dynload::cublasSetMathMode(handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif  // CUDA_VERSION >= 11000
      }
#endif  // CUDA_VERSION >= 9000
    } catch (...) {
      // Best-effort cleanup: the destroy status is deliberately ignored so
      // that the original exception is the one that propagates.
      dynload::cublasDestroy(handle_);
      handle_ = nullptr;
      throw;
    }
  }

  // Returns the raw handle. Access through this accessor is NOT guarded by
  // the internal mutex; use Call() for serialized access.
  const cublasHandle_t& GetCublasHandle() const { return handle_; }

  // Destroys the handle; declared PADDLE_MAY_THROW because the destroy call
  // is checked with PADDLE_RETRY_CUDA_SUCCESS.
  ~CublasHandleHolder() PADDLE_MAY_THROW {
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasDestroy(handle_));
  }

  // Invokes `callback` with the handle while holding the internal mutex, so
  // concurrent Call() invocations on the same holder do not overlap.
  void Call(const std::function<void(blasHandle_t)>& callback) const {
    std::lock_guard<std::mutex> guard(mtx_);
    callback(handle_);
  }

 private:
  DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);

  cublasHandle_t handle_{nullptr};
  // Mutable so that the const Call() can lock it.
  mutable std::mutex mtx_;
};

}  // namespace platform
}  // namespace paddle