device_tracer.cc 31.9 KB
Newer Older
X
Xin Pan 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#include <deque>
16
#include <forward_list>
X
Xin Pan 已提交
17
#include <fstream>
18 19 20 21
#include <mutex>  // NOLINT
#include <string>
#include <thread>  // NOLINT

22
#include "glog/logging.h"
C
chengduo 已提交
23
#include "paddle/fluid/platform/device_tracer.h"
24

25 26
DECLARE_bool(enable_host_event_recorder_hook);

27 28
namespace paddle {
namespace platform {
29 30 31 32

// Used only by DeviceTracer
uint64_t GetThreadIdFromSystemThreadId(uint32_t id);

33
namespace {
X
Xin Pan 已提交
34
// Tracking the nested block stacks of each thread.
W
Wilber 已提交
35 36 37 38 39 40
#ifdef PADDLE_WITH_SW
// sw not supported thread_local
std::deque<int> block_id_stack;
std::deque<Event *> annotation_stack;
#else
// Tracking the nested block stacks.
X
Xin Pan 已提交
41 42
thread_local std::deque<int> block_id_stack;
// Tracking the nested event stacks.
43
thread_local std::deque<Event *> annotation_stack;
W
Wilber 已提交
44
#endif
W
wangchaochaohu 已提交
45 46 47
// Stacks that store special events, such as PE events and so on.
static std::deque<Event *> main_thread_annotation_stack{};
static std::deque<std::string> main_thread_annotation_stack_name{};
48

49 50
std::map<uint32_t, uint64_t> system_thread_id_map;
std::mutex system_thread_id_map_mutex;
51 52 53

std::once_flag tracer_once_flag;
DeviceTracer *tracer = nullptr;
54 55 56 57 58

void PrintCuptiHint() {
  static bool showed = false;
  if (showed) return;
  showed = true;
T
tianshuo78520a 已提交
59
  LOG(WARNING) << "Invalid timestamp occurred. Please try increasing the "
60 61 62
                  "FLAGS_multiple_of_cupti_buffer_size.";
}

63 64 65 66
}  // namespace
#ifdef PADDLE_WITH_CUPTI

namespace {
67 68 69
// The experimental best performance is
// the same size with CUPTI device buffer size(8M)
uint64_t kBufSize = 1024 * 1024 * 8;
70
uint64_t kAlignSize = 8;
71 72
std::unordered_map<CUpti_CallbackId, std::string> runtime_cbid_str,
    driver_cbid_str;
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90

// Rounds the pointer `buffer` up to the next address that is a multiple of
// `align`; returns `buffer` unchanged when it is already aligned.
// The remainder is computed with the mask `(align)-1`, so `align` has to be a
// power of two. Both arguments are evaluated more than once, so they must be
// free of side effects.
#define ALIGN_BUFFER(buffer, align)                                 \
  (((uintptr_t)(buffer) & ((align)-1))                              \
       ? ((buffer) + (align) - ((uintptr_t)(buffer) & ((align)-1))) \
       : (buffer))

// Evaluates `call` (an expression yielding a CUptiResult). When the result is
// not CUPTI_SUCCESS, the macro prints the file, line, stringified call and the
// CUPTI error string to stderr and terminates the process with exit(-1).
// Wrapped in do { } while (0) so that it behaves as a single statement.
#define CUPTI_CALL(call)                                                   \
  do {                                                                     \
    CUptiResult _status = call;                                            \
    if (_status != CUPTI_SUCCESS) {                                        \
      const char *errstr;                                                  \
      dynload::cuptiGetResultString(_status, &errstr);                     \
      fprintf(stderr, "%s:%d: error: function %s failed with error %s.\n", \
              __FILE__, __LINE__, #call, errstr);                          \
      exit(-1);                                                            \
    }                                                                      \
  } while (0)

X
Xin Pan 已提交
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
// Maps a CUPTI memcpy kind to the label used in the profile
// ("MEMCPY_HtoD", "MEMCPY_DtoH", ...). Any kind that is not listed below is
// reported with the generic label "MEMCPY".
std::string MemcpyKind(CUpti_ActivityMemcpyKind kind) {
  switch (kind) {
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD:
      return "MEMCPY_HtoD";
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH:
      return "MEMCPY_DtoH";
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA:
      return "MEMCPY_HtoA";
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH:
      return "MEMCPY_AtoH";
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA:
      return "MEMCPY_AtoA";
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD:
      return "MEMCPY_AtoD";
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA:
      return "MEMCPY_DtoA";
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD:
      return "MEMCPY_DtoD";
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH:
      return "MEMCPY_HtoH";
    case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP:
      return "MEMCPY_PtoP";
    case CUPTI_ACTIVITY_MEMCPY_KIND_FORCE_INT:
      return "MEMCPY_FORCE_INT";
    default:
      return "MEMCPY";
  }
}

121 122 123 124 125 126 127 128 129 130 131 132 133 134
// Returns the name registered in driver_cbid_str for the driver callback id
// `cbid`, or "Driver API <cbid>" when the id has no registered name.
std::string DriverKind(CUpti_CallbackId cbid) {
  const auto found = driver_cbid_str.find(cbid);
  if (found != driver_cbid_str.end()) {
    return found->second;
  }
  return "Driver API " + std::to_string(cbid);
}

// Returns the name registered in runtime_cbid_str for the runtime callback id
// `cbid`, or "Runtime API <cbid>" when the id has no registered name.
std::string RuntimeKind(CUpti_CallbackId cbid) {
  const auto found = runtime_cbid_str.find(cbid);
  if (found != runtime_cbid_str.end()) {
    return found->second;
  }
  return "Runtime API " + std::to_string(cbid);
}

135 136 137 138
void EnableActivity() {
  // Device activity record is created when CUDA initializes, so we
  // want to enable it before cuInit() or any CUDA runtime call.
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY));
139 140 141 142 143
  CUPTI_CALL(
      dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
144
  // We don't track these activities for now.
D
Dun 已提交
145
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET));
146 147
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE));
148 149 150 151 152 153 154 155 156
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONTEXT));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MARKER));
}

void DisableActivity() {
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY));
157 158 159
  CUPTI_CALL(
      dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DEVICE));
160
  // Disable all other activity record kinds.
161
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONTEXT));
162 163
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME));
D
Dun 已提交
164
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET));
165 166 167
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MARKER));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD));
168 169 170 171
}

// Callback registered with CUPTI to hand out an empty activity buffer.
// Allocates kBufSize + kAlignSize bytes and reports a kBufSize-byte buffer
// whose start is rounded up to kAlignSize. *maxNumRecords is set to 0.
//
// NOTE(review): bufferCompleted() passes the pointer returned here to free().
// That is only valid while malloc() already returns kAlignSize-aligned memory
// (so the rounding is a no-op) -- confirm before raising kAlignSize.
void CUPTIAPI bufferRequested(uint8_t **buffer, size_t *size,
                              size_t *maxNumRecords) {
  uint8_t *raw = static_cast<uint8_t *>(malloc(kBufSize + kAlignSize));
  *buffer = ALIGN_BUFFER(raw, kAlignSize);
  *size = kBufSize;
  *maxNumRecords = 0;
}

// Callback registered with CUPTI for completed activity buffers.
//
// Iterates over the records stored in the first `validSize` bytes of
// `buffer` and forwards the supported kinds to the global `tracer`:
//   * kernel / concurrent kernel -> AddKernelRecords
//   * memcpy / memcpy2           -> AddMemRecords
//   * memset                     -> AddKernelRecords (named "MEMSET")
//   * driver / runtime API       -> AddActiveKindRecords (device id -1)
// Afterwards the number of dropped records for (ctx, streamId) is queried and
// reported, and `buffer` is released with free(). `size` is not used.
//
// The function enforces that every invocation happens on the same thread as
// the first invocation.
void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
                              size_t size, size_t validSize) {
  // A default-constructed std::thread::id does not identify any thread, which
  // makes it a portable "not yet recorded" marker.
  static std::thread::id cupti_thread_id;
  if (cupti_thread_id == std::thread::id())
    cupti_thread_id = std::this_thread::get_id();
  PADDLE_ENFORCE_EQ(
      std::this_thread::get_id(), cupti_thread_id,
      platform::errors::PermissionDenied(
          "Only one thread is allowed to call bufferCompleted()."));
  CUptiResult status;
  CUpti_Activity *record = nullptr;
  if (validSize > 0) {
    do {
      status = dynload::cuptiActivityGetNextRecord(buffer, validSize, &record);
      if (status == CUPTI_SUCCESS) {
        switch (record->kind) {
          case CUPTI_ACTIVITY_KIND_KERNEL:
          case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
#if CUDA_VERSION >= 9000
            auto *kernel =
                reinterpret_cast<const CUpti_ActivityKernel4 *>(record);
#else
            auto *kernel =
                reinterpret_cast<const CUpti_ActivityKernel3 *>(record);
#endif
            tracer->AddKernelRecords(kernel->name, kernel->start, kernel->end,
                                     kernel->deviceId, kernel->streamId,
                                     kernel->correlationId);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMCPY: {
            auto *memcpy =
                reinterpret_cast<const CUpti_ActivityMemcpy *>(record);
            tracer->AddMemRecords(
                MemcpyKind(
                    static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
                memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
                memcpy->correlationId, memcpy->bytes);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMCPY2: {
            auto *memcpy =
                reinterpret_cast<const CUpti_ActivityMemcpy2 *>(record);
            tracer->AddMemRecords(
                MemcpyKind(
                    static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
                memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
                memcpy->correlationId, memcpy->bytes);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMSET: {
            auto *memset =
                reinterpret_cast<const CUpti_ActivityMemset *>(record);
            tracer->AddKernelRecords("MEMSET", memset->start, memset->end,
                                     memset->deviceId, memset->streamId,
                                     memset->correlationId);
            break;
          }
          case CUPTI_ACTIVITY_KIND_DRIVER: {
            auto *api = reinterpret_cast<const CUpti_ActivityAPI *>(record);
            // Records without both timestamps are skipped.
            if (api->start != 0 && api->end != 0) {
              // -1 device id represents ActiveKind api call
              tracer->AddActiveKindRecords(
                  DriverKind(api->cbid), api->start, api->end, -1,
                  GetThreadIdFromSystemThreadId(api->threadId),
                  api->correlationId);
            }
            break;
          }
          case CUPTI_ACTIVITY_KIND_RUNTIME: {
            auto *api = reinterpret_cast<const CUpti_ActivityAPI *>(record);
            // Records without both timestamps are skipped.
            if (api->start != 0 && api->end != 0) {
              // -1 device id represents ActiveKind api call
              tracer->AddActiveKindRecords(
                  RuntimeKind(api->cbid), api->start, api->end, -1,
                  GetThreadIdFromSystemThreadId(api->threadId),
                  api->correlationId);
            }
            break;
          }
          default: { break; }
        }
      } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
        // Seems not an error in this case.
        break;
      } else {
        // Any other status is fatal: CUPTI_CALL prints it and exits.
        CUPTI_CALL(status);
      }
    } while (1);

    size_t dropped;
    CUPTI_CALL(
        dynload::cuptiActivityGetNumDroppedRecords(ctx, streamId, &dropped));
    if (dropped != 0) {
      // %zu prints the full size_t value instead of truncating it.
      fprintf(stderr, "Dropped %zu activity records\n", dropped);
      PrintCuptiHint();
    }
  }
  free(buffer);
}
278 279 280

void initCuptiCbidStr();

281 282
}  // namespace

Q
qiaolongfei 已提交
283 284
#endif  // PADDLE_WITH_CUPTI

285 286
class DeviceTracerImpl : public DeviceTracer {
 public:
287 288 289 290 291
  // Creates a disabled tracer. start_ns_ and end_ns_ are zero-initialized
  // because only the CUPTI code paths of Enable()/Disable() assign them,
  // while GetProfile() reads them unconditionally.
  DeviceTracerImpl() : enabled_(false), start_ns_(0), end_ns_(0) {
#ifdef PADDLE_WITH_CUPTI
    initCuptiCbidStr();
#endif
  }
292

293
  void AddAnnotation(uint32_t id, Event *event) {
W
Wilber 已提交
294 295 296 297
#ifdef PADDLE_WITH_SW
    std::forward_list<std::pair<uint32_t, Event *>> *local_correlations_pairs =
        nullptr;
#else
298 299
    thread_local std::forward_list<std::pair<uint32_t, Event *>>
        *local_correlations_pairs = nullptr;
W
Wilber 已提交
300
#endif
301 302 303 304 305 306
    if (local_correlations_pairs == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      correlations_pairs.emplace_front();
      local_correlations_pairs = &correlations_pairs.front();
    }
    local_correlations_pairs->push_front(std::make_pair(id, event));
307 308
  }

309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
  // Associates every stored ActiveKindRecord with an event of the same thread
  // and registers the pair through AddAnnotation().
  //
  // For each record, the events of record.thread_id are looked up in
  // `thr_events`. The chosen entry is the one with the smallest key that is
  // not less than record.end_ns. The record is skipped when the thread or
  // such an entry is missing, or when the entry's stored start timestamp is
  // later than record.start_ns.
  void AddAnnotations(const std::map<uint64_t, ThreadEvents> &thr_events) {
    for (auto &record_list : active_kind_records_) {
      for (const ActiveKindRecord &rec : record_list) {
        const auto thread_it = thr_events.find(rec.thread_id);
        if (thread_it == thr_events.end()) {
          VLOG(10) << __func__ << " " << rec.name
                   << " Missing tid: " << rec.thread_id;
          continue;
        }

        const ThreadEvents &events = thread_it->second;
        auto chosen = events.upper_bound(rec.end_ns);
        if (chosen == events.end()) {
          VLOG(10) << __func__ << " Missing Record " << rec.name
                   << " tid: " << rec.thread_id << " end_ns: " << rec.end_ns;
          continue;
        }

        // upper_bound() skips an entry whose key equals end_ns; step back to
        // it when it exists.
        if (chosen != events.begin()) {
          auto previous = std::prev(chosen);
          if (previous->first >= rec.end_ns) {
            chosen = previous;
          } else {
            VLOG(10) << __func__ << " prev end_ns " << previous->first
                     << " end_ns: " << rec.end_ns;
          }
        }

        Event *evt = chosen->second.first;
        uint64_t start_ns = chosen->second.second;
        if (start_ns > rec.start_ns) {
          VLOG(10) << __func__ << " Mismatch Record " << rec.name
                   << " tid: " << rec.thread_id
                   << " start_ns: " << rec.start_ns
                   << " end_ns: " << rec.end_ns << ", event " << evt->name()
                   << " start_ns: " << start_ns;
          continue;
        }

        VLOG(10) << __func__ << " tid: " << rec.thread_id
                 << " Add correlation " << rec.correlation_id << "<->"
                 << evt->name();
        AddAnnotation(rec.correlation_id, evt);
      }
    }
  }

X
Xin Pan 已提交
350 351 352
  void AddCPURecords(const std::string &anno, uint64_t start_ns,
                     uint64_t end_ns, int64_t device_id, int64_t thread_id) {
    if (anno.empty()) {
M
minqiyang 已提交
353
      VLOG(1) << "Empty timeline annotation.";
354 355
      return;
    }
W
Wilber 已提交
356 357 358
#ifdef PADDLE_WITH_SW
    std::forward_list<CPURecord> *local_cpu_records_ = nullptr;
#else
359
    thread_local std::forward_list<CPURecord> *local_cpu_records_ = nullptr;
W
Wilber 已提交
360
#endif
361 362 363 364 365 366
    if (local_cpu_records_ == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      cpu_records_.emplace_front();
      local_cpu_records_ = &cpu_records_.front();
    }
    local_cpu_records_->push_front(
X
Xin Pan 已提交
367
        CPURecord{anno, start_ns, end_ns, device_id, thread_id});
X
Xin Pan 已提交
368 369
  }

X
Xin Pan 已提交
370
  void AddMemRecords(const std::string &name, uint64_t start_ns,
X
Xin Pan 已提交
371
                     uint64_t end_ns, int64_t device_id, int64_t stream_id,
X
Xin Pan 已提交
372
                     uint32_t correlation_id, uint64_t bytes) {
X
Xin Pan 已提交
373
    // 0 means timestamp information could not be collected for the kernel.
374
    if (start_ns == 0 || end_ns == 0 || start_ns == end_ns) {
M
minqiyang 已提交
375
      VLOG(3) << name << " cannot be traced";
376
      PrintCuptiHint();
X
Xin Pan 已提交
377 378
      return;
    }
379 380 381
    // NOTE(liangdun): lock is not needed, only one thread call this function.
    mem_records_.push_front(MemRecord{name, start_ns, end_ns, device_id,
                                      stream_id, correlation_id, bytes});
X
Xin Pan 已提交
382 383
  }

C
chengduo 已提交
384 385 386 387 388 389 390
  void AddMemInfoRecord(uint64_t start_ns, uint64_t end_ns, size_t bytes,
                        const Place &place, const std::string &alloc_in,
                        const std::string &free_in, int64_t thread_id) {
    if (0 == start_ns || 0 == end_ns) {
      VLOG(3) << alloc_in << ", " << free_in << " Cannot be traced.";
      return;
    }
W
Wilber 已提交
391 392 393
#ifdef PADDLE_WITH_SW
    std::forward_list<MemInfoRecord> *local_mem_info_record = nullptr;
#else
C
chengduo 已提交
394 395
    thread_local std::forward_list<MemInfoRecord> *local_mem_info_record =
        nullptr;
W
Wilber 已提交
396
#endif
C
chengduo 已提交
397 398 399 400 401 402 403 404 405
    if (local_mem_info_record == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      mem_info_record_.emplace_front();
      local_mem_info_record = &mem_info_record_.front();
    }
    local_mem_info_record->emplace_front(MemInfoRecord{
        start_ns, end_ns, bytes, place, thread_id, alloc_in, free_in});
  }

406 407
  void AddActiveKindRecords(const std::string &anno, uint64_t start_ns,
                            uint64_t end_ns, int64_t device_id,
408
                            uint64_t thread_id, uint32_t correlation_id) {
409 410 411 412
    if (anno.empty()) {
      VLOG(1) << "Empty timeline annotation.";
      return;
    }
W
Wilber 已提交
413 414 415
#ifdef PADDLE_WITH_SW
    std::forward_list<ActiveKindRecord> *local_active_kind_records = nullptr;
#else
416 417
    thread_local std::forward_list<ActiveKindRecord>
        *local_active_kind_records = nullptr;
W
Wilber 已提交
418
#endif
419 420 421 422 423 424 425 426 427 428
    if (local_active_kind_records == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      active_kind_records_.emplace_front();
      local_active_kind_records = &active_kind_records_.front();
    }
    //  lock is not needed, only one thread call this function.
    local_active_kind_records->push_front(ActiveKindRecord{
        anno, start_ns, end_ns, device_id, thread_id, correlation_id});
  }

Z
ZongwuYang 已提交
429 430 431
  void AddKernelRecords(std::string name, uint64_t start, uint64_t end,
                        int64_t device_id, int64_t stream_id,
                        uint32_t correlation_id) {
X
Xin Pan 已提交
432
    // 0 means timestamp information could not be collected for the kernel.
433
    if (start == 0 || end == 0 || start == end) {
M
minqiyang 已提交
434
      VLOG(3) << correlation_id << " cannot be traced";
435
      PrintCuptiHint();
X
Xin Pan 已提交
436 437
      return;
    }
438 439
    // NOTE(liangdun): lock is not needed, only one thread call this function.
    kernel_records_.push_front(
Z
ZongwuYang 已提交
440
        KernelRecord{name, start, end, device_id, stream_id, correlation_id});
441 442 443 444 445 446 447 448 449 450 451 452
  }

  // Returns the enabled flag, read while holding trace_mu_.
  bool IsEnabled() {
    std::lock_guard<std::mutex> guard(trace_mu_);
    const bool currently_enabled = enabled_;
    return currently_enabled;
  }

  // Starts tracing; a no-op when tracing is already enabled.
  // With CUPTI this enables the activity kinds, registers the buffer
  // callbacks, subscribes ApiCallback, enables the runtime/driver callbacks
  // listed below and stores the start timestamp in start_ns_.
  void Enable() {
    std::lock_guard<std::mutex> guard(trace_mu_);
    if (enabled_) {
      return;
    }

#ifdef PADDLE_WITH_CUPTI
    EnableActivity();

    // Register callbacks for buffer requests and completed by CUPTI.
    CUPTI_CALL(dynload::cuptiActivityRegisterCallbacks(bufferRequested,
                                                       bufferCompleted));

    // NOTE(review): a failed subscription is only reported; the code below
    // still uses subscriber_ -- confirm whether this is intended.
    const CUptiResult subscribe_status = dynload::cuptiSubscribe(
        &subscriber_, static_cast<CUpti_CallbackFunc>(ApiCallback), this);
    if (subscribe_status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
      fprintf(stderr, "CUPTI subcriber limit reached.\n");
    } else if (subscribe_status != CUPTI_SUCCESS) {
      fprintf(stderr, "Failed to create CUPTI subscriber.\n");
    }

    // Runtime API entry points whose callbacks are enabled.
    const std::vector<int> runtime_cbids {
      CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaSetupArgument_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaMemset_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaMemsetAsync_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000
#if CUDA_VERSION >= 9000
          ,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000
#endif
    };
    // Driver API entry points whose callbacks are enabled.
    const std::vector<int> driver_cbids{CUPTI_DRIVER_TRACE_CBID_cuLaunch,
                                        CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid,
                                        CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel};
    for (auto cbid : runtime_cbids) {
      CUPTI_CALL(dynload::cuptiEnableCallback(
          1, subscriber_, CUPTI_CB_DOMAIN_RUNTIME_API, cbid));
    }
    for (auto cbid : driver_cbids) {
      CUPTI_CALL(dynload::cuptiEnableCallback(
          1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid));
    }
    CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_));
#endif  // PADDLE_WITH_CUPTI
    enabled_ = true;
  }

497 498 499 500 501 502 503 504 505 506 507
  void Reset() {
#ifdef PADDLE_WITH_CUPTI
    CUPTI_CALL(
        dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
#endif
    std::lock_guard<std::mutex> l(trace_mu_);
    kernel_records_.clear();
    mem_records_.clear();
    correlations_.clear();
    for (auto &tmp : correlations_pairs) tmp.clear();
    for (auto &tmp : cpu_records_) tmp.clear();
C
chengduo 已提交
508
    for (auto &tmp : mem_info_record_) tmp.clear();
509
    for (auto &tmp : active_kind_records_) tmp.clear();
510 511 512 513 514 515 516 517 518 519 520
  }

  // Adds the duration of every kernel and memory record to the event that is
  // correlated with it and to all of that event's ancestors (ancestors are
  // updated first). Records without a non-null correlated event are ignored.
  // The correlation map is built from correlations_pairs when it is empty.
  // Does nothing unless PADDLE_WITH_CUPTI is defined.
  void GenEventKernelCudaElapsedTime() {
#ifdef PADDLE_WITH_CUPTI
    if (correlations_.empty()) {
      for (auto &pair_list : correlations_pairs) {
        for (auto &entry : pair_list) {
          correlations_[entry.first] = entry.second;
        }
      }
    }

    // Charges [start_ns, end_ns] to the correlated event and its ancestors.
    const auto charge = [this](uint32_t correlation_id, uint64_t start_ns,
                               uint64_t end_ns) {
      const auto matched = correlations_.find(correlation_id);
      if (matched == correlations_.end() || matched->second == nullptr) {
        return;
      }
      Event *leaf = matched->second;
      for (Event *ancestor = leaf->parent(); ancestor;
           ancestor = ancestor->parent()) {
        ancestor->AddCudaElapsedTime(start_ns, end_ns);
      }
      leaf->AddCudaElapsedTime(start_ns, end_ns);
    };

    for (const KernelRecord &rec : kernel_records_) {
      charge(rec.correlation_id, rec.start_ns, rec.end_ns);
    }
    for (const auto &rec : mem_records_) {
      charge(rec.correlation_id, rec.start_ns, rec.end_ns);
    }
#endif
  }

X
Xin Pan 已提交
544
  // Builds the profile through GetProfile(), writes it in binary form to
  // `profile_path` (truncating an existing file) and returns it.
  // A file that cannot be opened or written is reported with a warning; the
  // in-memory profile is returned in every case.
  proto::Profile GenProfile(const std::string &profile_path) {
    proto::Profile profile_pb = this->GetProfile();
    std::ofstream profile_f;
    profile_f.open(profile_path,
                   std::ios::out | std::ios::trunc | std::ios::binary);
    if (!profile_f.is_open()) {
      LOG(WARNING) << "Failed to open " << profile_path
                   << " for writing, the profile is not saved.";
    } else if (!profile_pb.SerializeToOstream(&profile_f)) {
      LOG(WARNING) << "Failed to serialize the profile to " << profile_path
                   << ".";
    }
    profile_f.close();
    return profile_pb;
  }

  // Converts everything recorded so far into a proto::Profile while holding
  // trace_mu_. Events are appended in this order: kernel records, CPU
  // records, driver/runtime API records, memory-copy records; memory
  // allocation records go to the separate mem_events list.
  // When a record's correlation id maps to a non-null Event, the event name
  // is used as the profile event name. Throws Unimplemented for a memory
  // record whose place is not CPU, GPU, CUDA-pinned or NPU.
  proto::Profile GetProfile() {
    std::lock_guard<std::mutex> guard(trace_mu_);
    proto::Profile profile_pb;
    profile_pb.set_start_ns(start_ns_);
    profile_pb.set_end_ns(end_ns_);

    // Build the correlation map lazily from the per-thread pair lists.
    if (correlations_.empty()) {
      for (auto &pair_list : correlations_pairs) {
        for (auto &entry : pair_list) {
          correlations_[entry.first] = entry.second;
        }
      }
    }

    int miss = 0;
    int find = 0;

    // GPU kernels.
    for (const KernelRecord &rec : kernel_records_) {
      auto *event = profile_pb.add_events();
      event->set_type(proto::Event::GPUKernel);
      const auto matched = correlations_.find(rec.correlation_id);
      if (matched != correlations_.end() && matched->second != nullptr) {
        event->set_name(matched->second->name());
        event->set_detail_info(matched->second->attr());
        find++;
      } else {
        VLOG(10) << __func__ << " Missing Kernel Event: " << rec.name;
        miss++;
        event->set_name(rec.name);
      }
      event->set_start_ns(rec.start_ns);
      event->set_end_ns(rec.end_ns);
      event->set_sub_device_id(rec.stream_id);
      event->set_device_id(rec.device_id);
    }
    VLOG(1) << __func__ << " KernelRecord event miss: " << miss
            << " find: " << find;

    // CPU events.
    for (auto &record_list : cpu_records_) {
      for (const CPURecord &rec : record_list) {
        auto *event = profile_pb.add_events();
        event->set_type(proto::Event::CPU);
        event->set_name(rec.name);
        event->set_start_ns(rec.start_ns);
        event->set_end_ns(rec.end_ns);
        event->set_sub_device_id(rec.thread_id);
        event->set_device_id(rec.device_id);
      }
    }

    // Driver/runtime API calls, reported as CPU events.
    for (auto &record_list : active_kind_records_) {
      for (const ActiveKindRecord &rec : record_list) {
        auto *event = profile_pb.add_events();
        event->set_type(proto::Event::CPU);
        const auto matched = correlations_.find(rec.correlation_id);
        if (matched != correlations_.end() && matched->second != nullptr) {
          event->set_name(matched->second->name());
          event->set_detail_info(rec.name);
        } else {
          event->set_name(rec.name);
        }
        event->set_start_ns(rec.start_ns);
        event->set_end_ns(rec.end_ns);
        event->set_sub_device_id(rec.thread_id);
        event->set_device_id(rec.device_id);
      }
    }

    // Memory copies, reported as GPU kernel events with a byte count.
    miss = 0;
    find = 0;
    for (const MemRecord &rec : mem_records_) {
      auto *event = profile_pb.add_events();
      event->set_type(proto::Event::GPUKernel);
      const auto matched = correlations_.find(rec.correlation_id);
      if (matched != correlations_.end() && matched->second != nullptr) {
        event->set_name(matched->second->name());
        event->set_detail_info(rec.name);
        find++;
      } else {
        miss++;
        event->set_name(rec.name);
      }
      event->set_start_ns(rec.start_ns);
      event->set_end_ns(rec.end_ns);
      event->set_sub_device_id(rec.stream_id);
      event->set_device_id(rec.device_id);
      event->mutable_memcopy()->set_bytes(rec.bytes);
    }
    VLOG(1) << __func__ << " MemRecord event miss: " << miss
            << " find: " << find;

    // Memory allocations.
    for (auto &record_list : mem_info_record_) {
      for (const auto &rec : record_list) {
        auto *event = profile_pb.add_mem_events();
        // Device id defaults to 0 and is only overridden for GPU places.
        event->set_device_id(0);
        if (platform::is_cpu_place(rec.place)) {
          event->set_place(proto::MemEvent::CPUPlace);
        } else if (platform::is_gpu_place(rec.place)) {
          event->set_place(proto::MemEvent::CUDAPlace);
          event->set_device_id(
              BOOST_GET_CONST(platform::CUDAPlace, rec.place).GetDeviceId());
        } else if (platform::is_cuda_pinned_place(rec.place)) {
          event->set_place(proto::MemEvent::CUDAPinnedPlace);
        } else if (platform::is_npu_place(rec.place)) {
          event->set_place(proto::MemEvent::NPUPlace);
        } else {
          PADDLE_THROW(platform::errors::Unimplemented(
              "The current place is not supported."));
        }
        event->set_alloc_in(rec.alloc_in);
        event->set_free_in(rec.free_in);
        event->set_start_ns(rec.start_ns);
        event->set_end_ns(rec.end_ns);
        event->set_bytes(rec.bytes);
        event->set_thread_id(rec.thread_id);
      }
    }
    return profile_pb;
  }

  void Disable() {
Q
qiaolongfei 已提交
668
#ifdef PADDLE_WITH_CUPTI
669
    // flush might cause additional calls to DeviceTracker.
670 671
    CUPTI_CALL(
        dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
Q
qiaolongfei 已提交
672 673 674
#endif  // PADDLE_WITH_CUPTI
    std::lock_guard<std::mutex> l(trace_mu_);
#ifdef PADDLE_WITH_CUPTI
675
    DisableActivity();
676
    CUPTI_CALL(dynload::cuptiUnsubscribe(subscriber_));
677
    CUPTI_CALL(dynload::cuptiGetTimestamp(&end_ns_));
Q
qiaolongfei 已提交
678
#endif  // PADDLE_WITH_CUPTI
679 680 681 682
    enabled_ = false;
  }

 private:
Q
qiaolongfei 已提交
683
#ifdef PADDLE_WITH_CUPTI
684 685
  // Callback subscribed in Enable(). Returns immediately when
  // FLAGS_enable_host_event_recorder_hook is set. Otherwise, on API entry it
  // links the call's correlation id to the calling thread's current
  // annotation (which may be nullptr). `userdata` is the DeviceTracerImpl
  // passed to cuptiSubscribe; `domain` and `cbid` are not used.
  static void CUPTIAPI ApiCallback(void *userdata, CUpti_CallbackDomain domain,
                                   CUpti_CallbackId cbid, const void *cbdata) {
    if (LIKELY(FLAGS_enable_host_event_recorder_hook)) {
      return;
    }
    const auto *info = reinterpret_cast<const CUpti_CallbackData *>(cbdata);
    if (info->callbackSite != CUPTI_API_ENTER) {
      return;
    }
    DeviceTracerImpl *self = reinterpret_cast<DeviceTracerImpl *>(userdata);
    Event *current = CurAnnotation();
    self->AddAnnotation(info->correlationId, current);
  }
  CUpti_SubscriberHandle subscriber_;
#endif  // PADDLE_WITH_CUPTI
Q
qiaolongfei 已提交
698 699 700 701
  std::mutex trace_mu_;
  bool enabled_;
  uint64_t start_ns_;
  uint64_t end_ns_;
702 703 704
  std::forward_list<KernelRecord> kernel_records_;
  std::forward_list<MemRecord> mem_records_;
  std::forward_list<std::forward_list<CPURecord>> cpu_records_;
C
chengduo 已提交
705
  std::forward_list<std::forward_list<MemInfoRecord>> mem_info_record_;
706
  std::forward_list<std::forward_list<ActiveKindRecord>> active_kind_records_;
707 708 709
  std::forward_list<std::forward_list<std::pair<uint32_t, Event *>>>
      correlations_pairs;
  std::unordered_map<uint32_t, Event *> correlations_;
710 711
};

Q
qiaolongfei 已提交
712
// Allocates a DeviceTracerImpl and stores it in *t. The object is never
// deleted in the code visible here.
void CreateTracer(DeviceTracer **t) {
  DeviceTracer *created = new DeviceTracerImpl();
  *t = created;
}
713 714 715 716 717 718

// Returns the process-wide tracer, creating it on the first call.
// std::call_once ensures CreateTracer runs only once.
DeviceTracer *GetDeviceTracer() {
  std::call_once(tracer_once_flag, CreateTracer, &tracer);
  DeviceTracer *instance = tracer;
  return instance;
}

719 720 721 722 723 724
// Pushes `event` onto the calling thread's annotation stack.
//
// When the thread already has an open event, `event` becomes its child and
// its name is prefixed with "<parent name>/". Otherwise, in order to record
// PE time, an event that runs on a different thread than the innermost
// special event (e.g. PE::run) is treated as that special event's child.
// Events whose role is EventRole::kSpecial are additionally pushed onto the
// main-thread stacks.
void SetCurAnnotation(Event *event) {
  if (!annotation_stack.empty()) {
    Event *enclosing = annotation_stack.back();
    event->set_parent(enclosing);
    event->set_name(enclosing->name() + "/" + event->name());
  } else if (!main_thread_annotation_stack.empty() &&
             main_thread_annotation_stack.back()->thread_id() !=
                 event->thread_id()) {
    Event *special = main_thread_annotation_stack.back();
    event->set_parent(special);
    event->set_name(special->name() + "/" + event->name());
  }
  annotation_stack.push_back(event);

  if (event->role() != EventRole::kSpecial) {
    return;
  }
  std::string qualified_name = event->name();
  if (!main_thread_annotation_stack_name.empty()) {
    qualified_name =
        main_thread_annotation_stack_name.back() + "/" + event->name();
  }
  main_thread_annotation_stack_name.push_back(qualified_name);
  main_thread_annotation_stack.push_back(event);
}
X
Xin Pan 已提交
745

W
wangchaochaohu 已提交
746
void ClearCurAnnotation() {
747
  if (!main_thread_annotation_stack.empty()) {
748 749 750 751 752 753
    std::string name = annotation_stack.back()->name();
    std::string main_name = main_thread_annotation_stack.back()->name();
    int main_name_len = main_name.length();
    int name_len = name.length();
    int prefix_len = main_name_len - name_len;

754 755 756
    if ((prefix_len > 0 && main_name.at(prefix_len - 1) == '/' &&
         name == main_name.substr(prefix_len, name_len)) ||
        (name == main_name)) {
757 758 759
      main_thread_annotation_stack_name.pop_back();
      main_thread_annotation_stack.pop_back();
    }
W
wangchaochaohu 已提交
760 761 762
  }
  annotation_stack.pop_back();
}
X
Xin Pan 已提交
763

764 765
// Returns the innermost event in annotation_stack, or nullptr when the stack
// holds no events.
Event *CurAnnotation() {
  return annotation_stack.empty() ? nullptr : annotation_stack.back();
}
768

769
std::string CurAnnotationName() {
C
chengduo 已提交
770
  if (annotation_stack.empty()) return "Unknown";
771 772
  return annotation_stack.back()->name();
}
X
Xin Pan 已提交
773 774 775 776 777 778

// Pushes block_id onto block_id_stack as the innermost block.
void SetCurBlock(int block_id) {
  block_id_stack.emplace_back(block_id);
}

// Pops the innermost block id from block_id_stack.
// Does nothing when the stack is empty: pop_back() on an empty deque is
// undefined, so an unbalanced call becomes a no-op.
void ClearCurBlock() {
  if (block_id_stack.empty()) return;
  block_id_stack.pop_back();
}

// Returns how many block ids are currently held in block_id_stack.
int BlockDepth() {
  return static_cast<int>(block_id_stack.size());
}
779 780 781 782 783 784 785 786

// Returns the calling thread's std::thread::id as a 32-bit integer.
//
// The id is streamed to text and parsed back into a number. The textual form
// of std::thread::id is implementation-defined, so the parse uses base 0,
// which lets std::stoull auto-detect the radix. A hexadecimal rendering such
// as "0x16f..." is then parsed in full; with base 10 the parse would stop
// after the leading "0" and every thread would get id 0. Plain decimal output
// parses to the same value as with base 10.
//
// The parsed value is truncated to its low 32 bits. std::stoull throws
// std::invalid_argument / std::out_of_range if the text is not a number.
uint32_t GetCurSystemThreadId() {
  std::stringstream ss;
  ss << std::this_thread::get_id();
  return static_cast<uint32_t>(std::stoull(ss.str(), nullptr, 0));
}

787 788
void RecoreCurThreadId(uint64_t id) {
  std::lock_guard<std::mutex> lock(system_thread_id_map_mutex);
789 790 791 792
  auto gid = GetCurSystemThreadId();
  system_thread_id_map[gid] = id;
}

793
uint64_t GetThreadIdFromSystemThreadId(uint32_t id) {
794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849
  auto it = system_thread_id_map.find(id);
  if (it != system_thread_id_map.end()) return it->second;
  // return origin id if no event is recorded in this thread.
  return static_cast<int32_t>(id);
}

#ifdef PADDLE_WITH_CUPTI
namespace {

// Fills runtime_cbid_str with a readable name for each CUPTI runtime
// callback id listed below. For every listed name X, the entry keyed by
// CUPTI_RUNTIME_TRACE_CBID_X is set to the string "X".
//
// The body runs only on the first call; later calls return immediately.
// NOTE(review): the `called` flag is a plain bool with no synchronization, so
// this guard is only safe if the first call is not made concurrently --
// confirm against callers.
void initCuptiCbidStr() {
  static bool called = false;
  if (called) return;
  called = true;
// Maps one callback id to its stringified name.
#define REGISTER_RUNTIME_CBID_STR(cbid) \
  runtime_cbid_str[CUPTI_RUNTIME_TRACE_CBID_##cbid] = #cbid

  REGISTER_RUNTIME_CBID_STR(cudaBindTexture_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaConfigureCall_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetAttribute_v5000);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetStreamPriorityRange_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceSynchronize_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDriverGetVersion_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventCreateWithFlags_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventDestroy_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventQuery_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventRecord_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFreeHost_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFree_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFuncGetAttributes_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDeviceCount_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDeviceProperties_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDevice_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetErrorString_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetLastError_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaHostAlloc_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaHostGetDevicePointer_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaLaunchKernel_v7000);
  REGISTER_RUNTIME_CBID_STR(cudaMallocHost_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMalloc_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemcpyAsync_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemcpy_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemsetAsync_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemset_v3020);
  REGISTER_RUNTIME_CBID_STR(
      cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_v7000);
  REGISTER_RUNTIME_CBID_STR(cudaPeekAtLastError_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaRuntimeGetVersion_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaSetDevice_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreate_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithFlags_v5000);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithPriority_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaStreamDestroy_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaStreamSynchronize_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamWaitEvent_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaUnbindTexture_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaSetupArgument_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaLaunch_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetPCIBusId_v4010);
// These callback ids are only registered when building against CUDA 9.0+.
#if CUDA_VERSION >= 9000
  REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernel_v9000);
  REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernelMultiDevice_v9000);
#endif

#undef REGISTER_RUNTIME_CBID_STR
}
}  // namespace
#endif  // PADDLE_WITH_CUPTI

863 864
}  // namespace platform
}  // namespace paddle