cuda_graph.cc 10.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
16 17 18
#include <queue>
#include <unordered_map>
#include <unordered_set>
19 20 21 22 23

namespace paddle {
namespace platform {

std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
24
paddle::optional<std::thread::id> CUDAGraph::capturing_thread_id_{paddle::none};
25

26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
// Returns all nodes of `graph` in a topological order: every node appears
// after all nodes it depends on (Kahn's algorithm over the graph's edges).
//
// Ties are broken by the order in which cudaGraphGetNodes / cudaGraphGetEdges
// report nodes and edges, so the result is deterministic for a given graph
// (it does not depend on hash-container iteration order). Callers rely on
// this order to match random kernels to their seed setters.
//
// Throws (via PADDLE_ENFORCE) on any CUDA error, or if not every node could
// be ordered (which would indicate a cycle, i.e. a bug).
static std::vector<cudaGraphNode_t> ToposortCUDAGraph(cudaGraph_t graph) {
  size_t num_nodes = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
  std::vector<cudaGraphNode_t> nodes(num_nodes);
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaGraphGetNodes(graph, nodes.data(), &num_nodes));

  size_t num_edges = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaGraphGetEdges(graph, nullptr, nullptr, &num_edges));
  std::vector<cudaGraphNode_t> from(num_edges), to(num_edges);
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaGraphGetEdges(graph, from.data(), to.data(), &num_edges));

  // in_degree[n]: number of dependencies of n not yet emitted.
  // out_edges[n]: dependents of n, in edge-report order.
  std::unordered_map<cudaGraphNode_t, size_t> in_degree;
  std::unordered_map<cudaGraphNode_t, std::vector<cudaGraphNode_t>> out_edges;
  in_degree.reserve(nodes.size());
  out_edges.reserve(nodes.size());
  for (auto node : nodes) {
    in_degree[node] = 0;
    out_edges[node];
  }
  // A repeated edge increments and later decrements the count the same
  // number of times, so duplicates are handled consistently.
  for (size_t i = 0; i < num_edges; ++i) {
    ++in_degree[to[i]];
    out_edges[from[i]].push_back(to[i]);
  }

  // Seed with the roots in node-report order (deterministic).
  std::queue<cudaGraphNode_t> ready;
  for (auto node : nodes) {
    if (in_degree.at(node) == 0) {
      ready.push(node);
    }
  }

  std::vector<cudaGraphNode_t> sorted;
  sorted.reserve(nodes.size());
  while (!ready.empty()) {
    auto cur = ready.front();
    ready.pop();
    sorted.push_back(cur);
    for (auto next : out_edges.at(cur)) {
      if (--in_degree.at(next) == 0) {
        ready.push(next);
      }
    }
  }
  PADDLE_ENFORCE_EQ(
      sorted.size(), num_nodes,
      phi::errors::InvalidArgument("Toposort error, this may be a bug."));
  return sorted;
}

// Hands out graph IDs 0, 1, 2, ... from a single function-local atomic
// counter, so concurrent callers never receive the same ID.
CUDAGraphID CUDAGraph::UniqueID() {
  static std::atomic<CUDAGraphID> next_graph_id(0);
  const auto assigned = next_graph_id.fetch_add(1);
  return assigned;
}

// Hands out memory pool IDs from an atomic counter that starts one past
// kDefaultPoolID, so a generated ID is never equal to the default pool ID.
int64_t CUDAGraph::UniqueMemoryPoolID() {
  static std::atomic<int64_t> next_pool_id(CUDAGraph::kDefaultPoolID + 1);
  const int64_t assigned = next_pool_id.fetch_add(1);
  return assigned;
}

89 90 91
void CUDAGraph::Reset() {
  if (is_reset_) return;
#if CUDA_VERSION >= 10010
92
  for (auto graph : graphs_) {
93
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(graph));
94
  }
95 96
  graphs_.clear();
  for (auto exec_graph : exec_graphs_) {
97
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecDestroy(exec_graph));
98
  }
99
  exec_graphs_.clear();
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
#endif
  // callback should be called in reverse order because the latter added
  // callback may rely on the former added callback.
  for (auto iter = callbacks_.rbegin(); iter != callbacks_.rend(); ++iter) {
    (*iter)();
  }
  callbacks_.clear();
  is_reset_ = true;
}

// Launches every captured segment, in capture order, on stream_.
// On every call except the first, each segment's pre-hooks are run against
// its executable graph before that segment is launched. The hooks are
// installed by EndSegmentCapture; they re-run the matching set_seed_func and
// update the kernel node's parameters.
// Throws PermissionDenied if Reset() has already been called.
void CUDAGraph::Replay() {
#if CUDA_VERSION >= 10010
  PADDLE_ENFORCE_EQ(is_reset_, false,
                    errors::PermissionDenied(
                        "Cannot replay the CUDA Graph after reset is called."));
  size_t n = exec_graphs_.size();
  for (size_t i = 0; i < n; ++i) {
    if (!is_first_run_) {
      for (auto &hook : pre_hooks_[i]) {
        hook(exec_graphs_[i]);
      }
    }
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graphs_[i], stream_));
  }
  is_first_run_ = false;
#endif
}

// Starts capturing a new segment of the graph that is currently being
// captured, on its stream_ and with its capture_mode_.
// Requires an active capture (see BeginCapture). In thread-local capture mode
// it must be called from the thread that started the capture. After starting,
// it checks that the stream capture is still valid.
void CUDAGraph::BeginSegmentCapture() {
  ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
  PADDLE_ENFORCE_EQ(
      IsCapturing(), true,
      errors::PermissionDenied("BeginSegmentCapture should be called when CUDA "
                               "Graph is capturing."));
  if (IsThreadLocalCapturing()) {
    PADDLE_ENFORCE_EQ(IsThisThreadCapturing(), true,
                      platform::errors::PermissionDenied(
                          "When capturing CUDA Graph in the thread local mode, "
                          "you cannot begin segmented capturing in the thread "
                          "which is not the one that starts the capturing."));
  }
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamBeginCapture(
      capturing_graph_->stream_, capturing_graph_->capture_mode_));
  PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
                    platform::errors::PermissionDenied(
                        "CUDA Graph should not be invalidated."));
  // The segment id is the number of segments already recorded.
  VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_
           << ", segment id " << capturing_graph_->graphs_.size()
           << ", memory pool id " << capturing_graph_->pool_id_;
#endif
}

// Creates a new CUDAGraph for `place`, recording `stream` and `mode` on it,
// and begins capturing its first segment.
// Only one graph may be captured at a time, and `stream` must not be the
// default stream (null). In cudaStreamCaptureModeThreadLocal mode, the
// calling thread is recorded as the capturing thread.
void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
                             cudaStreamCaptureMode mode) {
  ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
  PADDLE_ENFORCE_EQ(
      IsCapturing(), false,
      errors::PermissionDenied("CUDA Graph can only captured one by one."));
  PADDLE_ENFORCE_NOT_NULL(
      stream, errors::PermissionDenied(
                  "CUDA Graph cannot be captured in default CUDA stream 0."));
  capturing_graph_.reset(new CUDAGraph());
  capturing_graph_->place_ = place;
  capturing_graph_->stream_ = stream;
  capturing_graph_->capture_mode_ = mode;
  if (mode == cudaStreamCaptureModeThreadLocal) {
    capturing_thread_id_ = std::this_thread::get_id();
    VLOG(10) << "Capturing CUDA Graph in thread local mode, thread id: "
             << capturing_thread_id_;
  }
  BeginSegmentCapture();
#endif
}

176
void CUDAGraph::EndSegmentCapture() {
177 178 179 180
  ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
  PADDLE_ENFORCE_EQ(IsCapturing(), true,
                    errors::PermissionDenied("No CUDA Graph is capturing."));
181
  cudaGraph_t graph;
182
  PADDLE_ENFORCE_GPU_SUCCESS(
183 184
      cudaStreamEndCapture(capturing_graph_->stream_, &graph));
  auto num_nodes = static_cast<size_t>(-1);
185
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
186
  if (num_nodes == 0) {
187
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(graph));
188
    VLOG(10) << "Skip empty CUDA Graph with ID " << capturing_graph_->id_
189 190
             << ", segment id " << capturing_graph_->graphs_.size()
             << ", memory pool id " << capturing_graph_->pool_id_;
191 192 193
    return;
  }

194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
  auto sorted_nodes = ToposortCUDAGraph(graph);
  capturing_graph_->pre_hooks_.emplace_back();
  std::unordered_set<cudaGraphNode_t> visited;
  VLOG(10) << "SetSeedFunc number : "
           << capturing_graph_->set_seed_funcs_.size();
  for (const auto &set_seed_func : capturing_graph_->set_seed_funcs_) {
    bool found = false;
    for (auto node : sorted_nodes) {
      if (visited.count(node) > 0) continue;
      cudaGraphNodeType type;
      PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphNodeGetType(node, &type));
      if (type == cudaGraphNodeTypeKernel) {
        cudaKernelNodeParams params;
        auto err = cudaGraphKernelNodeGetParams(node, &params);
        if (err == cudaErrorInvalidDeviceFunction) {
          continue;
        } else {
          PADDLE_ENFORCE_GPU_SUCCESS(err);
        }
        CUDAKernelParams kernel_params(&params);
        if (set_seed_func(&kernel_params, true)) {
          capturing_graph_->pre_hooks_.back().push_back(
              [set_seed_func, node, params](cudaGraphExec_t exec_graph) {
                CUDAKernelParams kernel_params(&params);
                set_seed_func(&kernel_params, false);
                PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecKernelNodeSetParams(
                    exec_graph, node, &params));
              });
          visited.insert(node);
          found = true;
          break;
        }
      }
    }
    PADDLE_ENFORCE_EQ(found, true,
                      phi::errors::InvalidArgument(
                          "Cannot find the corresponding random CUDA kernel."));
  }
  capturing_graph_->set_seed_funcs_.clear();

234
  cudaGraphExec_t exec_graph;
235
  PADDLE_ENFORCE_GPU_SUCCESS(
236 237
      cudaGraphInstantiate(&exec_graph, graph, nullptr, nullptr, 0));
  VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_
238 239
           << ", segment id " << capturing_graph_->graphs_.size()
           << ", memory pool id " << capturing_graph_->pool_id_;
240 241
  capturing_graph_->graphs_.emplace_back(graph);
  capturing_graph_->exec_graphs_.emplace_back(exec_graph);
242 243 244
#endif
}

245 246
std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
  EndSegmentCapture();
247
  capturing_thread_id_ = paddle::none;
248 249 250
  return std::move(capturing_graph_);
}

251
bool CUDAGraph::IsValidCapturing() {
252
#if CUDA_VERSION >= 10010
253 254 255
  if (!IsCapturing()) return false;
  cudaStreamCaptureStatus status;
  CUDAGraphID id;
256
  PADDLE_ENFORCE_GPU_SUCCESS(
257 258
      cudaStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id));
  return status == cudaStreamCaptureStatusActive;
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
#else
  return false;
#endif
}

// Joins `dirname` and `filename` with the platform path separator.
// The separator is not doubled when `dirname` already ends with one.
// An empty `dirname` yields `filename` unchanged, i.e. a path relative to the
// working directory, rather than "/filename" at the filesystem root.
static std::string ConcatPath(const std::string &dirname,
                              const std::string &filename) {
#ifdef _WIN32
  const char kFileSep[] = "\\";
#else
  const char kFileSep[] = "/";
#endif
  if (dirname.empty()) {
    return filename;
  }
  if (dirname.back() == kFileSep[0]) {
    return dirname + filename;
  }
  return dirname + kFileSep + filename;
}

// Writes each captured segment i to "<dirname>/segment_<i>.dot" by calling
// cudaGraphDebugDotPrint with `flags`.
// Requires CUDA >= 11.3; otherwise throws Unimplemented.
void CUDAGraph::PrintToDotFiles(const std::string &dirname,
                                unsigned int flags) {
  ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 11030
  for (size_t i = 0; i < graphs_.size(); ++i) {
    auto filename =
        ConcatPath(dirname, "segment_" + std::to_string(i) + ".dot");
    VLOG(10) << "Save the " << i << "-th segment of graph " << id_ << " to "
             << filename;
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaGraphDebugDotPrint(graphs_[i], filename.c_str(), flags));
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "The print_to_dot_files() method is only supported when CUDA version >= "
      "11.3."));
#endif
}

}  // namespace platform
}  // namespace paddle