// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <atomic>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
#include "cuda.h"          // NOLINT
#include "cuda_runtime.h"  // NOLINT
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/utils/optional.h"

namespace paddle {
namespace platform {

// Compare two values byte-by-byte. Note that padding bytes are compared
// too, so values that are logically equal may still compare unequal if
// their padding differs.
template <typename T>
static bool IsBitwiseEqual(const T &x, const T &y) {
  return std::memcmp(&x, &y, sizeof(T)) == 0;
}

class CUDAKernelParams {
 public:
  explicit CUDAKernelParams(const cudaKernelNodeParams *params)
      : params_(params) {}

  const void *func() const { return params_->func; }

  template <typename T>
  T &As(size_t idx) const {
    return *reinterpret_cast<T *>(params_->kernelParams[idx]);
  }

 private:
  const cudaKernelNodeParams *params_;
};
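
// Illustrative sketch (not part of this header): given the
// cudaKernelNodeParams of a captured kernel node, CUDAKernelParams exposes
// the kernel's launch arguments by index. The kernel signature and the
// node pointer below are hypothetical.
//
//   // __global__ void AxpyKernel(float a, const float *x, float *y, int n);
//   // const cudaKernelNodeParams *node = ...;
//   CUDAKernelParams params(node);
//   float &a = params.As<float>(0);  // first launch argument
//   int &n = params.As<int>(3);      // fourth launch argument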

template <typename F, F f>
struct IsSameKernelHelper;

template <typename Return, typename... FuncArgs,
          Return (*kernel_fn)(FuncArgs...)>
struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
 private:
  using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));

  template <typename TupleT, size_t IDX, bool IsEnd /*=false*/>
  struct Impl {
    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
      using CompareT = typename std::tuple_element<IDX, FuncArgsTuple>::type;
      if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
                                    std::get<IDX>(args))) {
        return false;
      }

      constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
      return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
    }
  };

  template <typename TupleT, size_t IDX>
  struct Impl<TupleT, IDX, true> {
    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
      return true;
    }
  };

 public:
  template <typename... Args>
  static bool Compare(const CUDAKernelParams &params, Args... args) {
    constexpr auto kNumArgs = sizeof...(FuncArgs);
    static_assert(kNumArgs == sizeof...(Args),
                  "The number of arguments does not match");

    auto args_tuple = std::make_tuple(args...);
    using TupleT = typename std::decay<decltype(args_tuple)>::type;
    return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
  }
};
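
// Illustrative sketch: IsSameKernelHelper binds a kernel function pointer
// at compile time; Compare() then bitwise-compares each captured launch
// argument against the expected values. The kernel and arguments below are
// hypothetical.
//
//   // __global__ void AxpyKernel(float a, const float *x, float *y, int n);
//   using Helper = IsSameKernelHelper<decltype(&AxpyKernel), &AxpyKernel>;
//   bool same = Helper::Compare(params, 2.0f, x_ptr, y_ptr, 1024);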

#if CUDA_VERSION >= 10010
static void ThrowErrorIfNotSupportCUDAGraph() {}
#else
enum cudaStreamCaptureMode {
  cudaStreamCaptureModeGlobal = 0,
  cudaStreamCaptureModeThreadLocal = 1,
  cudaStreamCaptureModeRelaxed = 2
};
static void ThrowErrorIfNotSupportCUDAGraph() {
  PADDLE_THROW(platform::errors::Unimplemented(
      "CUDA Graph is only supported when CUDA version >= 10.1"));
}
#endif

// NOTE: Currently, we do not support capturing CUDA graphs in parallel.
// NOTE: Do not use this class directly; it is designed to be used
//       together with the memory pool.
class CUDAGraph {
  DISABLE_COPY_AND_ASSIGN(CUDAGraph);

  // Since the constructor throws an error if CUDA_VERSION < 10010,
  // the non-static methods of CUDAGraph need not check CUDA_VERSION
  // again.
  CUDAGraph() {
    ThrowErrorIfNotSupportCUDAGraph();
    id_ = UniqueID();
  }

 public:
  static constexpr int64_t kDefaultPoolID = 0;
  static constexpr int64_t kInvalidPoolID = -1;

  ~CUDAGraph() { Reset(); }

  CUDAGraphID ID() const { return id_; }

  static int64_t SetMemoryPoolID(int64_t pool_id) {
    auto &pool_id_ = capturing_graph_->pool_id_;
    PADDLE_ENFORCE_EQ(
        pool_id_, kInvalidPoolID,
        phi::errors::InvalidArgument("Cannot reset memory pool id twice, the "
                                     "former memory pool id is %d.",
                                     pool_id_));
    if (pool_id <= kInvalidPoolID) {
      pool_id_ = UniqueMemoryPoolID();
    } else {
      PADDLE_ENFORCE_GE(
          pool_id, kDefaultPoolID,
          phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id));
      pool_id_ = pool_id;
    }
    return pool_id_;
  }
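
  // Illustrative usage (must be called while capturing, since it writes to
  // capturing_graph_): passing an id <= kInvalidPoolID allocates a fresh
  // unique pool id, while a non-negative id is reused as-is.
  //
  //   int64_t id = CUDAGraph::SetMemoryPoolID(CUDAGraph::kInvalidPoolID);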

  int64_t PoolID() const { return pool_id_; }

  static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; }

  void Replay();

  void Reset();

  void AddResetCallback(std::function<void()> callback) {
    std::lock_guard<std::mutex> guard(mtx_);
    callbacks_.push_back(std::move(callback));
  }

  void PrintToDotFiles(const std::string &dirname, unsigned int flags);

  static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
                           cudaStreamCaptureMode mode);
  static std::unique_ptr<CUDAGraph> EndCapture();

  static void BeginSegmentCapture();
  static void EndSegmentCapture();
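
  // Illustrative capture/replay lifecycle (error handling omitted; the
  // stream below is assumed to be a dedicated capture stream):
  //
  //   platform::CUDAPlace place(0);
  //   // cudaStream_t stream = ...;
  //   CUDAGraph::BeginCapture(place, stream, cudaStreamCaptureModeThreadLocal);
  //   // ... launch the GPU work to be captured on `stream` ...
  //   auto graph = CUDAGraph::EndCapture();
  //   graph->Replay();  // instantiate (on the first run) and launch the graph
  //   graph->Reset();   // destroy the graph and run the reset callbacks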

  static void AddResetCallbackDuringCapturing(std::function<void()> callback) {
    capturing_graph_->AddResetCallback(std::move(callback));
  }

  // No need to guard this with a CUDA_VERSION macro because
  // capturing_graph_ would always be nullptr (the constructor throws
  // an error when CUDA Graph is unsupported).
  static bool IsCapturing() { return capturing_graph_ != nullptr; }

  static CUDAGraphID CapturingID() { return capturing_graph_->id_; }

  static platform::CUDAPlace CapturingPlace() {
    return capturing_graph_->place_;
  }

  // This API can be used to debug which GPU operations are not
  // supported during CUDA Graph capturing.
  static bool IsValidCapturing();

  static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
    return IsCapturing() &&
           capturing_graph_->capture_mode_ == cudaStreamCaptureModeThreadLocal;
#else
    return false;
#endif
  }

  static bool IsThisThreadCapturing() {
    if (UNLIKELY(IsCapturing())) {
      return IsThreadLocalCapturing()
                 ? capturing_thread_id_.get() == std::this_thread::get_id()
                 : true;
    } else {
      return false;
    }
  }

  using SetSeedFunc = std::function<bool(CUDAKernelParams *, bool)>;
  static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) {
    std::lock_guard<std::mutex> guard(capturing_graph_->func_mtx_);
    capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func));
  }
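
  // Illustrative shape of a SetSeedFunc (the exact semantics of the bool
  // parameter and return value are defined in the corresponding .cc file;
  // the argument index and NewSeed() helper below are hypothetical):
  //
  //   CUDAGraph::SetSeedFunc f = [](CUDAKernelParams *params, bool is_first) {
  //     params->As<uint64_t>(0) = NewSeed();  // hypothetical seed slot
  //     return true;
  //   };
  //   CUDAGraph::RecordRandomKernelInfo(std::move(f));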

  static int64_t UniqueMemoryPoolID();

 private:
  static CUDAGraphID UniqueID();

 private:
#if CUDA_VERSION >= 10010
  std::vector<cudaGraph_t> graphs_;
  std::vector<cudaGraphExec_t> exec_graphs_;
  cudaStreamCaptureMode capture_mode_;
#endif
  cudaStream_t stream_{nullptr};
  platform::CUDAPlace place_;
  CUDAGraphID id_;
  int64_t pool_id_{kInvalidPoolID};
  std::vector<std::function<void()>> callbacks_;
  bool is_reset_{false};
  std::mutex mtx_;

  std::vector<SetSeedFunc> set_seed_funcs_;
  std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
  std::mutex func_mtx_;

  bool is_first_run_{true};

  static paddle::optional<std::thread::id> capturing_thread_id_;
  static std::unique_ptr<CUDAGraph> capturing_graph_;
};

#if CUDA_VERSION >= 10010
class CUDAGraphCaptureModeGuard {
  DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);

 public:
  explicit CUDAGraphCaptureModeGuard(
      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {
    if (UNLIKELY(CUDAGraph::IsCapturing())) {
      PADDLE_ENFORCE_GPU_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode));
      // After cudaThreadExchangeStreamCaptureMode returns, the variable
      // "mode" holds the previous capture mode.
      old_mode_ = mode;
    }
  }

  ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW {
    if (UNLIKELY(CUDAGraph::IsCapturing())) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaThreadExchangeStreamCaptureMode(&old_mode_));
    }
  }

 private:
  cudaStreamCaptureMode old_mode_;
};
#else
class CUDAGraphCaptureModeGuard {
  DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);

 public:
  explicit CUDAGraphCaptureModeGuard(
      cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {}
};
#endif
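
// Illustrative usage: while capturing, temporarily switch this thread's
// capture mode (relaxed by default) so that calls which would otherwise
// invalidate the capture can proceed; the previous mode is restored when
// the guard goes out of scope.
//
//   {
//     CUDAGraphCaptureModeGuard guard;  // cudaStreamCaptureModeRelaxed
//     // ... GPU calls disallowed under stricter capture modes ...
//   }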

}  // namespace platform
}  // namespace paddle