device_tracer.cc 27.4 KB
Newer Older
X
Xin Pan 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13

licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#include <deque>
16
#include <forward_list>
X
Xin Pan 已提交
17
#include <fstream>
18
#include <list>
19
#include <map>
20
#include <mutex>  // NOLINT
21
#include <numeric>
22
#include <sstream>
23 24
#include <string>
#include <thread>  // NOLINT
25 26
#include <unordered_map>
#include <utility>
27 28
#include <vector>

29
#include "glog/logging.h"
30
#include "google/protobuf/text_format.h"
31
#include "paddle/fluid/framework/block_desc.h"
C
chengduo 已提交
32 33
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/profiler.h"
34 35 36 37 38
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace platform {
namespace {
X
Xin Pan 已提交
39 40 41
// Tracking the nested block stacks of each thread.
thread_local std::deque<int> block_id_stack;
// Tracking the nested event stacks.
42 43 44
thread_local std::deque<Event *> annotation_stack;

std::map<uint32_t, int32_t> system_thread_id_map;
45 46 47

std::once_flag tracer_once_flag;
DeviceTracer *tracer = nullptr;
48 49 50 51 52 53 54 55 56

// Logs, at most once per process, a warning that suggests enlarging the CUPTI
// buffer. Callers in this file invoke it when a record carries an unusable
// timestamp or when CUPTI reports dropped records.
// NOTE(review): the `showed` flag is a plain static bool with no
// synchronization - confirm all callers run on the same thread.
void PrintCuptiHint() {
  static bool showed = false;
  if (showed) return;
  showed = true;
  LOG(WARNING) << "Invalid timestamp occurred. Please try increasing the "
                  "FLAGS_multiple_of_cupti_buffer_size.";
}

57 58 59 60
}  // namespace
#ifdef PADDLE_WITH_CUPTI

namespace {
61 62 63
// The experimental best performance is
// the same size with CUPTI device buffer size(8M)
uint64_t kBufSize = 1024 * 1024 * 8;
64
uint64_t kAlignSize = 8;
65 66
std::unordered_map<CUpti_CallbackId, std::string> runtime_cbid_str,
    driver_cbid_str;
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84

// Rounds the pointer `buffer` up to the next multiple of `align` and evaluates
// to `buffer` itself when it is already aligned. The mask arithmetic
// `& ((align)-1)` only works when `align` is a power of two. Both arguments
// are evaluated more than once, so pass side-effect-free expressions.
#define ALIGN_BUFFER(buffer, align)                                 \
  (((uintptr_t)(buffer) & ((align)-1))                              \
       ? ((buffer) + (align) - ((uintptr_t)(buffer) & ((align)-1))) \
       : (buffer))

// Evaluates `call` (an expression yielding a CUptiResult) exactly once. On any
// result other than CUPTI_SUCCESS it prints the file, line, the text of the
// call and the CUPTI error string to stderr, then terminates the process with
// exit(-1).
// `errstr` starts with a placeholder so that a valid string is printed even
// if cuptiGetResultString itself fails and leaves the pointer untouched.
#define CUPTI_CALL(call)                                                   \
  do {                                                                     \
    CUptiResult _status = (call);                                          \
    if (_status != CUPTI_SUCCESS) {                                        \
      const char *errstr = "<unknown error>";                              \
      dynload::cuptiGetResultString(_status, &errstr);                     \
      fprintf(stderr, "%s:%d: error: function %s failed with error %s.\n", \
              __FILE__, __LINE__, #call, errstr);                          \
      exit(-1);                                                            \
    }                                                                      \
  } while (0)

X
Xin Pan 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
// Translates a CUPTI memcpy kind into the label used for the record name.
// Kinds without a dedicated label map to the generic "MEMCPY".
std::string MemcpyKind(CUpti_ActivityMemcpyKind kind) {
  const char *label = "MEMCPY";
  switch (kind) {
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD:
      label = "MEMCPY_HtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH:
      label = "MEMCPY_DtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA:
      label = "MEMCPY_HtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH:
      label = "MEMCPY_AtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA:
      label = "MEMCPY_AtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD:
      label = "MEMCPY_AtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA:
      label = "MEMCPY_DtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD:
      label = "MEMCPY_DtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH:
      label = "MEMCPY_HtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP:
      label = "MEMCPY_PtoP";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_FORCE_INT:
      label = "MEMCPY_FORCE_INT";
      break;
    default:
      break;
  }
  return label;
}

115 116 117 118 119 120 121 122 123 124 125 126 127 128
// Returns the registered name of a driver API callback id, or
// "Driver API <id>" when the id has no entry in driver_cbid_str.
std::string DriverKind(CUpti_CallbackId cbid) {
  const auto known = driver_cbid_str.find(cbid);
  if (known != driver_cbid_str.end()) {
    return known->second;
  }
  return "Driver API " + std::to_string(cbid);
}

// Returns the registered name of a runtime API callback id, or
// "Runtime API <id>" when the id has no entry in runtime_cbid_str.
std::string RuntimeKind(CUpti_CallbackId cbid) {
  const auto known = runtime_cbid_str.find(cbid);
  if (known != runtime_cbid_str.end()) {
    return known->second;
  }
  return "Runtime API " + std::to_string(cbid);
}

129 130 131 132
void EnableActivity() {
  // Device activity record is created when CUDA initializes, so we
  // want to enable it before cuInit() or any CUDA runtime call.
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY));
133 134 135 136 137
  CUPTI_CALL(
      dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
138
  // We don't track these activities for now.
D
Dun 已提交
139
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET));
140 141
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE));
142 143 144 145 146 147 148 149 150
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONTEXT));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MARKER));
}

void DisableActivity() {
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY));
151 152 153
  CUPTI_CALL(
      dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DEVICE));
154
  // Disable all other activity record kinds.
155
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONTEXT));
156 157
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME));
D
Dun 已提交
158
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET));
159 160 161
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MARKER));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD));
162 163 164 165
}

// Buffer-request callback registered with CUPTI. Allocates kBufSize bytes
// (plus slack for alignment) and returns the aligned start in `*buffer`.
//
// buffer:        receives the buffer start, or nullptr if allocation failed.
// size:          receives the usable size in bytes (0 on allocation failure).
// maxNumRecords: always set to 0; presumably "no record limit" per the CUPTI
//                activity API - TODO confirm.
//
// NOTE(review): bufferCompleted() passes the aligned pointer to free(), which
// is only valid while malloc already returns kAlignSize-aligned memory so
// that ALIGN_BUFFER leaves the pointer unchanged - confirm for all targets.
void CUPTIAPI bufferRequested(uint8_t **buffer, size_t *size,
                              size_t *maxNumRecords) {
  *maxNumRecords = 0;
  uint8_t *buf = static_cast<uint8_t *>(malloc(kBufSize + kAlignSize));
  if (buf == nullptr) {
    // Do not advertise kBufSize usable bytes behind a null pointer.
    *buffer = nullptr;
    *size = 0;
    return;
  }
  *size = kBufSize;
  *buffer = ALIGN_BUFFER(buf, kAlignSize);
}

// Buffer-completed callback registered with CUPTI. Walks every activity record
// in `buffer`, forwards kernel / memcpy / memset / driver / runtime records to
// the global `tracer`, reports dropped records, and finally frees the buffer.
//
// ctx, streamId: passed on to cuptiActivityGetNumDroppedRecords.
// buffer:        buffer previously handed out by bufferRequested().
// size:          unused.
// validSize:     number of bytes in `buffer` that hold records.
void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
                              size_t size, size_t validSize) {
  // A default-constructed std::thread::id represents "no thread"; it is
  // replaced by the first caller's id, and every later call must come from
  // that same thread.
  static std::thread::id cupti_thread_id;
  if (cupti_thread_id == std::thread::id())
    cupti_thread_id = std::this_thread::get_id();
  PADDLE_ENFORCE_EQ(std::this_thread::get_id(), cupti_thread_id,
                    "Only one thread is allowed to call bufferCompleted()");
  CUptiResult status;
  CUpti_Activity *record = nullptr;
  if (validSize > 0) {
    do {
      status = dynload::cuptiActivityGetNextRecord(buffer, validSize, &record);
      if (status == CUPTI_SUCCESS) {
        switch (record->kind) {
          case CUPTI_ACTIVITY_KIND_KERNEL:
          case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
#if CUDA_VERSION >= 9000
            auto *kernel =
                reinterpret_cast<const CUpti_ActivityKernel4 *>(record);
#else
            auto *kernel =
                reinterpret_cast<const CUpti_ActivityKernel3 *>(record);
#endif
            tracer->AddKernelRecords(kernel->name, kernel->start, kernel->end,
                                     kernel->deviceId, kernel->streamId,
                                     kernel->correlationId);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMCPY: {
            auto *memcpy =
                reinterpret_cast<const CUpti_ActivityMemcpy *>(record);
            tracer->AddMemRecords(
                MemcpyKind(
                    static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
                memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
                memcpy->correlationId, memcpy->bytes);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMCPY2: {
            auto *memcpy =
                reinterpret_cast<const CUpti_ActivityMemcpy2 *>(record);
            tracer->AddMemRecords(
                MemcpyKind(
                    static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
                memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
                memcpy->correlationId, memcpy->bytes);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMSET: {
            // Memset activity is reported as a kernel record named "MEMSET".
            auto *memset =
                reinterpret_cast<const CUpti_ActivityMemset *>(record);
            tracer->AddKernelRecords("MEMSET", memset->start, memset->end,
                                     memset->deviceId, memset->streamId,
                                     memset->correlationId);
            break;
          }
          case CUPTI_ACTIVITY_KIND_DRIVER: {
            auto *api = reinterpret_cast<const CUpti_ActivityAPI *>(record);
            if (api->start != 0 && api->end != 0) {
              // -1 device id represents ActiveKind api call
              tracer->AddActiveKindRecords(
                  DriverKind(api->cbid), api->start, api->end, -1,
                  GetThreadIdFromSystemThreadId(api->threadId),
                  api->correlationId);
            }
            break;
          }
          case CUPTI_ACTIVITY_KIND_RUNTIME: {
            auto *api = reinterpret_cast<const CUpti_ActivityAPI *>(record);
            if (api->start != 0 && api->end != 0) {
              // -1 device id represents ActiveKind api call
              tracer->AddActiveKindRecords(
                  RuntimeKind(api->cbid), api->start, api->end, -1,
                  GetThreadIdFromSystemThreadId(api->threadId),
                  api->correlationId);
            }
            break;
          }
          default: { break; }
        }
      } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
        // Seems not an error in this case: the buffer has no more records.
        break;
      } else {
        // Any other status is fatal; CUPTI_CALL prints it and exits.
        CUPTI_CALL(status);
      }
    } while (1);

    size_t dropped;
    CUPTI_CALL(
        dynload::cuptiActivityGetNumDroppedRecords(ctx, streamId, &dropped));
    if (dropped != 0) {
      fprintf(stderr, "Dropped %u activity records\n",
              static_cast<unsigned int>(dropped));
      PrintCuptiHint();
    }
  }
  free(buffer);
}
270 271 272

void initCuptiCbidStr();

273 274
}  // namespace

Q
qiaolongfei 已提交
275 276
#endif  // PADDLE_WITH_CUPTI

277 278
class DeviceTracerImpl : public DeviceTracer {
 public:
279 280 281 282 283
  DeviceTracerImpl() : enabled_(false) {
#ifdef PADDLE_WITH_CUPTI
    initCuptiCbidStr();
#endif
  }
284

285 286 287 288 289 290 291 292 293
  void AddAnnotation(uint32_t id, Event *event) {
    thread_local std::forward_list<std::pair<uint32_t, Event *>>
        *local_correlations_pairs = nullptr;
    if (local_correlations_pairs == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      correlations_pairs.emplace_front();
      local_correlations_pairs = &correlations_pairs.front();
    }
    local_correlations_pairs->push_front(std::make_pair(id, event));
294 295
  }

X
Xin Pan 已提交
296 297 298
  void AddCPURecords(const std::string &anno, uint64_t start_ns,
                     uint64_t end_ns, int64_t device_id, int64_t thread_id) {
    if (anno.empty()) {
M
minqiyang 已提交
299
      VLOG(1) << "Empty timeline annotation.";
300 301
      return;
    }
302 303 304 305 306 307 308
    thread_local std::forward_list<CPURecord> *local_cpu_records_ = nullptr;
    if (local_cpu_records_ == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      cpu_records_.emplace_front();
      local_cpu_records_ = &cpu_records_.front();
    }
    local_cpu_records_->push_front(
X
Xin Pan 已提交
309
        CPURecord{anno, start_ns, end_ns, device_id, thread_id});
X
Xin Pan 已提交
310 311
  }

X
Xin Pan 已提交
312
  void AddMemRecords(const std::string &name, uint64_t start_ns,
X
Xin Pan 已提交
313
                     uint64_t end_ns, int64_t device_id, int64_t stream_id,
X
Xin Pan 已提交
314
                     uint32_t correlation_id, uint64_t bytes) {
X
Xin Pan 已提交
315
    // 0 means timestamp information could not be collected for the kernel.
316
    if (start_ns == 0 || end_ns == 0 || start_ns == end_ns) {
M
minqiyang 已提交
317
      VLOG(3) << name << " cannot be traced";
318
      PrintCuptiHint();
X
Xin Pan 已提交
319 320
      return;
    }
321 322 323
    // NOTE(liangdun): lock is not needed, only one thread call this function.
    mem_records_.push_front(MemRecord{name, start_ns, end_ns, device_id,
                                      stream_id, correlation_id, bytes});
X
Xin Pan 已提交
324 325
  }

C
chengduo 已提交
326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343
  void AddMemInfoRecord(uint64_t start_ns, uint64_t end_ns, size_t bytes,
                        const Place &place, const std::string &alloc_in,
                        const std::string &free_in, int64_t thread_id) {
    if (0 == start_ns || 0 == end_ns) {
      VLOG(3) << alloc_in << ", " << free_in << " Cannot be traced.";
      return;
    }
    thread_local std::forward_list<MemInfoRecord> *local_mem_info_record =
        nullptr;
    if (local_mem_info_record == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      mem_info_record_.emplace_front();
      local_mem_info_record = &mem_info_record_.front();
    }
    local_mem_info_record->emplace_front(MemInfoRecord{
        start_ns, end_ns, bytes, place, thread_id, alloc_in, free_in});
  }

344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362
  void AddActiveKindRecords(const std::string &anno, uint64_t start_ns,
                            uint64_t end_ns, int64_t device_id,
                            int64_t thread_id, uint32_t correlation_id) {
    if (anno.empty()) {
      VLOG(1) << "Empty timeline annotation.";
      return;
    }
    thread_local std::forward_list<ActiveKindRecord>
        *local_active_kind_records = nullptr;
    if (local_active_kind_records == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      active_kind_records_.emplace_front();
      local_active_kind_records = &active_kind_records_.front();
    }
    //  lock is not needed, only one thread call this function.
    local_active_kind_records->push_front(ActiveKindRecord{
        anno, start_ns, end_ns, device_id, thread_id, correlation_id});
  }

Z
ZongwuYang 已提交
363 364 365
  void AddKernelRecords(std::string name, uint64_t start, uint64_t end,
                        int64_t device_id, int64_t stream_id,
                        uint32_t correlation_id) {
X
Xin Pan 已提交
366
    // 0 means timestamp information could not be collected for the kernel.
367
    if (start == 0 || end == 0 || start == end) {
M
minqiyang 已提交
368
      VLOG(3) << correlation_id << " cannot be traced";
369
      PrintCuptiHint();
X
Xin Pan 已提交
370 371
      return;
    }
372 373
    // NOTE(liangdun): lock is not needed, only one thread call this function.
    kernel_records_.push_front(
Z
ZongwuYang 已提交
374
        KernelRecord{name, start, end, device_id, stream_id, correlation_id});
375 376 377 378 379 380 381 382 383 384 385 386
  }

  bool IsEnabled() {
    std::lock_guard<std::mutex> l(trace_mu_);
    return enabled_;
  }

  void Enable() {
    std::lock_guard<std::mutex> l(trace_mu_);
    if (enabled_) {
      return;
    }
Q
qiaolongfei 已提交
387 388

#ifdef PADDLE_WITH_CUPTI
389 390 391 392 393 394 395 396 397 398 399 400 401 402
    EnableActivity();

    // Register callbacks for buffer requests and completed by CUPTI.
    CUPTI_CALL(dynload::cuptiActivityRegisterCallbacks(bufferRequested,
                                                       bufferCompleted));

    CUptiResult ret;
    ret = dynload::cuptiSubscribe(
        &subscriber_, static_cast<CUpti_CallbackFunc>(ApiCallback), this);
    if (ret == CUPTI_ERROR_MAX_LIMIT_REACHED) {
      fprintf(stderr, "CUPTI subcriber limit reached.\n");
    } else if (ret != CUPTI_SUCCESS) {
      fprintf(stderr, "Failed to create CUPTI subscriber.\n");
    }
403 404
    const std::vector<int> cbids {
      CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020,
405
          CUPTI_RUNTIME_TRACE_CBID_cudaSetupArgument_v3020,
406
          CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020,
D
Dun 已提交
407 408
          CUPTI_RUNTIME_TRACE_CBID_cudaMemset_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaMemsetAsync_v3020,
409 410 411 412 413 414 415 416 417 418 419
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000
#if CUDA_VERSION >= 9000
          ,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000
#endif
    };
    for (auto cbid : cbids)
      CUPTI_CALL(dynload::cuptiEnableCallback(
          1, subscriber_, CUPTI_CB_DOMAIN_RUNTIME_API, cbid));
420
    CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_));
Q
qiaolongfei 已提交
421
#endif  // PADDLE_WITH_CUPTI
422 423 424
    enabled_ = true;
  }

425 426 427 428 429 430 431 432 433 434 435
  void Reset() {
#ifdef PADDLE_WITH_CUPTI
    CUPTI_CALL(
        dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
#endif
    std::lock_guard<std::mutex> l(trace_mu_);
    kernel_records_.clear();
    mem_records_.clear();
    correlations_.clear();
    for (auto &tmp : correlations_pairs) tmp.clear();
    for (auto &tmp : cpu_records_) tmp.clear();
C
chengduo 已提交
436
    for (auto &tmp : mem_info_record_) tmp.clear();
437
    for (auto &tmp : active_kind_records_) tmp.clear();
438 439 440 441 442 443 444 445 446 447 448
  }

  void GenEventKernelCudaElapsedTime() {
#ifdef PADDLE_WITH_CUPTI
    if (correlations_.empty())
      for (auto &tmp : correlations_pairs)
        for (auto &pair : tmp) correlations_[pair.first] = pair.second;
    for (const KernelRecord &r : kernel_records_) {
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        Event *e = c->second;
449 450 451 452 453
        Event *parent = e->parent();
        while (parent) {
          parent->AddCudaElapsedTime(r.start_ns, r.end_ns);
          parent = parent->parent();
        }
454 455 456 457 458 459 460
        e->AddCudaElapsedTime(r.start_ns, r.end_ns);
      }
    }
    for (const auto &r : mem_records_) {
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        Event *e = c->second;
461 462 463 464 465
        Event *parent = e->parent();
        while (parent) {
          parent->AddCudaElapsedTime(r.start_ns, r.end_ns);
          parent = parent->parent();
        }
466 467 468 469 470 471
        e->AddCudaElapsedTime(r.start_ns, r.end_ns);
      }
    }
#endif
  }

X
Xin Pan 已提交
472
  proto::Profile GenProfile(const std::string &profile_path) {
473
    int miss = 0, find = 0;
474 475 476 477
    std::lock_guard<std::mutex> l(trace_mu_);
    proto::Profile profile_pb;
    profile_pb.set_start_ns(start_ns_);
    profile_pb.set_end_ns(end_ns_);
C
chengduo 已提交
478 479
    if (correlations_.empty()) {
      for (auto &tmp : correlations_pairs) {
480
        for (auto &pair : tmp) correlations_[pair.first] = pair.second;
C
chengduo 已提交
481 482 483
      }
    }

484 485
    for (const KernelRecord &r : kernel_records_) {
      auto *event = profile_pb.add_events();
X
Xin Pan 已提交
486
      event->set_type(proto::Event::GPUKernel);
487 488 489 490 491
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        event->set_name(c->second->name());
        event->set_detail_info(r.name);
        find++;
Z
ZongwuYang 已提交
492
      } else {
493 494
        VLOG(10) << "Missing Kernel Event: " + r.name;
        miss++;
Z
ZongwuYang 已提交
495 496
        event->set_name(r.name);
      }
497 498
      event->set_start_ns(r.start_ns);
      event->set_end_ns(r.end_ns);
X
Xin Pan 已提交
499
      event->set_sub_device_id(r.stream_id);
500
      event->set_device_id(r.device_id);
X
Xin Pan 已提交
501
    }
502
    VLOG(1) << "KernelRecord event miss: " << miss << " find: " << find;
C
chengduo 已提交
503

504
    for (auto &tmp : cpu_records_) {
505 506 507 508 509 510 511 512 513
      for (const CPURecord &r : tmp) {
        auto *event = profile_pb.add_events();
        event->set_type(proto::Event::CPU);
        event->set_name(r.name);
        event->set_start_ns(r.start_ns);
        event->set_end_ns(r.end_ns);
        event->set_sub_device_id(r.thread_id);
        event->set_device_id(r.device_id);
      }
514
    }
C
chengduo 已提交
515

516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532
    for (auto &tmp : active_kind_records_) {
      for (const ActiveKindRecord &r : tmp) {
        auto *event = profile_pb.add_events();
        event->set_type(proto::Event::CPU);
        auto c = correlations_.find(r.correlation_id);
        if (c != correlations_.end() && c->second != nullptr) {
          event->set_name(c->second->name());
          event->set_detail_info(r.name);
        } else {
          event->set_name(r.name);
        }
        event->set_start_ns(r.start_ns);
        event->set_end_ns(r.end_ns);
        event->set_sub_device_id(r.thread_id);
        event->set_device_id(r.device_id);
      }
    }
533
    miss = find = 0;
X
Xin Pan 已提交
534 535
    for (const MemRecord &r : mem_records_) {
      auto *event = profile_pb.add_events();
X
Xin Pan 已提交
536
      event->set_type(proto::Event::GPUKernel);
537 538 539 540 541 542 543 544 545
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        event->set_name(c->second->name());
        event->set_detail_info(r.name);
        find++;
      } else {
        miss++;
        event->set_name(r.name);
      }
X
Xin Pan 已提交
546 547
      event->set_start_ns(r.start_ns);
      event->set_end_ns(r.end_ns);
X
Xin Pan 已提交
548
      event->set_sub_device_id(r.stream_id);
X
Xin Pan 已提交
549 550 551
      event->set_device_id(r.device_id);
      event->mutable_memcopy()->set_bytes(r.bytes);
    }
552
    VLOG(1) << "MemRecord event miss: " << miss << " find: " << find;
C
chengduo 已提交
553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577

    for (auto &tmp : mem_info_record_) {
      for (const auto &r : tmp) {
        auto *event = profile_pb.add_mem_events();
        event->set_device_id(0);
        if (platform::is_cpu_place(r.place)) {
          event->set_place(proto::MemEvent::CPUPlace);
        } else if (platform::is_gpu_place(r.place)) {
          event->set_place(proto::MemEvent::CUDAPlace);
          event->set_device_id(
              boost::get<platform::CUDAPlace>(r.place).GetDeviceId());
        } else if (platform::is_cuda_pinned_place(r.place)) {
          event->set_place(proto::MemEvent::CUDAPinnedPlace);
        } else {
          PADDLE_THROW("The current place is not supported.");
        }
        event->set_alloc_in(r.alloc_in);
        event->set_free_in(r.free_in);
        event->set_start_ns(r.start_ns);
        event->set_end_ns(r.end_ns);
        event->set_bytes(r.bytes);
        event->set_thread_id(r.thread_id);
      }
    }

X
Xin Pan 已提交
578
    std::ofstream profile_f;
579 580 581
    profile_f.open(profile_path,
                   std::ios::out | std::ios::trunc | std::ios::binary);
    profile_pb.SerializeToOstream(&profile_f);
X
Xin Pan 已提交
582
    profile_f.close();
583 584 585 586
    return profile_pb;
  }

  void Disable() {
Q
qiaolongfei 已提交
587
#ifdef PADDLE_WITH_CUPTI
588
    // flush might cause additional calls to DeviceTracker.
589 590
    CUPTI_CALL(
        dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
Q
qiaolongfei 已提交
591 592 593
#endif  // PADDLE_WITH_CUPTI
    std::lock_guard<std::mutex> l(trace_mu_);
#ifdef PADDLE_WITH_CUPTI
594
    DisableActivity();
595
    CUPTI_CALL(dynload::cuptiUnsubscribe(subscriber_));
596
    CUPTI_CALL(dynload::cuptiGetTimestamp(&end_ns_));
Q
qiaolongfei 已提交
597
#endif  // PADDLE_WITH_CUPTI
598 599 600 601
    enabled_ = false;
  }

 private:
Q
qiaolongfei 已提交
602
#ifdef PADDLE_WITH_CUPTI
603 604 605
  static void CUPTIAPI ApiCallback(void *userdata, CUpti_CallbackDomain domain,
                                   CUpti_CallbackId cbid, const void *cbdata) {
    auto *cbInfo = reinterpret_cast<const CUpti_CallbackData *>(cbdata);
606 607 608 609
    DeviceTracerImpl *tracer = reinterpret_cast<DeviceTracerImpl *>(userdata);
    if (cbInfo->callbackSite == CUPTI_API_ENTER) {
      Event *event = CurAnnotation();
      tracer->AddAnnotation(cbInfo->correlationId, event);
610 611 612 613
    }
  }
  CUpti_SubscriberHandle subscriber_;
#endif  // PADDLE_WITH_CUPTI
Q
qiaolongfei 已提交
614 615 616 617
  std::mutex trace_mu_;
  bool enabled_;
  uint64_t start_ns_;
  uint64_t end_ns_;
618 619 620
  std::forward_list<KernelRecord> kernel_records_;
  std::forward_list<MemRecord> mem_records_;
  std::forward_list<std::forward_list<CPURecord>> cpu_records_;
C
chengduo 已提交
621
  std::forward_list<std::forward_list<MemInfoRecord>> mem_info_record_;
622
  std::forward_list<std::forward_list<ActiveKindRecord>> active_kind_records_;
623 624 625
  std::forward_list<std::forward_list<std::pair<uint32_t, Event *>>>
      correlations_pairs;
  std::unordered_map<uint32_t, Event *> correlations_;
626 627
};

Q
qiaolongfei 已提交
628
// Allocates a DeviceTracerImpl and stores its address in `*t`. Nothing in
// this file deletes the object.
void CreateTracer(DeviceTracer **t) {
  DeviceTracer *created = new DeviceTracerImpl();
  *t = created;
}
629 630 631 632 633 634

// Returns the tracer instance stored in `tracer`, creating it on the first
// call. std::call_once makes the creation happen exactly once even when
// several threads call this concurrently.
DeviceTracer *GetDeviceTracer() {
  std::call_once(tracer_once_flag, [] { CreateTracer(&tracer); });
  return tracer;
}

635 636 637 638 639 640 641
// Makes `event` the innermost annotation of the calling thread. When another
// annotation is already active, `event` becomes its child and its name is
// prefixed with "<parent name>/".
void SetCurAnnotation(Event *event) {
  if (annotation_stack.empty()) {
    annotation_stack.push_back(event);
    return;
  }
  Event *enclosing = annotation_stack.back();
  event->set_parent(enclosing);
  event->set_name(enclosing->name() + "/" + event->name());
  annotation_stack.push_back(event);
}
X
Xin Pan 已提交
642 643 644

// Pops the innermost annotation of the calling thread. An unbalanced call on
// an empty stack is ignored, because pop_back() on an empty deque is
// undefined behaviour.
void ClearCurAnnotation() {
  if (annotation_stack.empty()) return;
  annotation_stack.pop_back();
}

645 646
// Returns the innermost annotation event of the calling thread, or nullptr
// when no annotation is active.
Event *CurAnnotation() {
  return annotation_stack.empty() ? nullptr : annotation_stack.back();
}
649
std::string CurAnnotationName() {
C
chengduo 已提交
650
  if (annotation_stack.empty()) return "Unknown";
651 652
  return annotation_stack.back()->name();
}
X
Xin Pan 已提交
653 654 655 656 657 658

// Pushes `block_id` onto the calling thread's block stack.
void SetCurBlock(int block_id) {
  block_id_stack.emplace_back(block_id);
}

// Pops the innermost block id of the calling thread. An unbalanced call on an
// empty stack is ignored, because pop_back() on an empty deque is undefined
// behaviour.
void ClearCurBlock() {
  if (block_id_stack.empty()) return;
  block_id_stack.pop_back();
}

// Returns how many block ids are currently on the calling thread's stack.
int BlockDepth() {
  const auto depth = block_id_stack.size();
  return static_cast<int>(depth);
}
659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729

// Returns the calling thread's id as a 32-bit number. The id is obtained by
// streaming std::this_thread::get_id() into text, parsing that text as an
// unsigned integer and truncating the result to 32 bits.
uint32_t GetCurSystemThreadId() {
  std::stringstream printed;
  printed << std::this_thread::get_id();
  const unsigned long long wide_id = std::stoull(printed.str());
  return static_cast<uint32_t>(wide_id);
}

// Registers `id` as the profiler thread id of the calling thread, keyed by
// the value GetCurSystemThreadId() returns for it. A later registration for
// the same thread overwrites the earlier one.
void RecoreCurThreadId(int32_t id) {
  const uint32_t system_id = GetCurSystemThreadId();
  VLOG(1) << "RecoreCurThreadId: " << system_id << " -> " << id;
  system_thread_id_map[system_id] = id;
}

// Translates a system thread id into the id registered through
// RecoreCurThreadId(). Threads that never registered keep their system id,
// converted to int32_t.
int32_t GetThreadIdFromSystemThreadId(uint32_t id) {
  const auto registered = system_thread_id_map.find(id);
  if (registered == system_thread_id_map.end()) {
    // return origin id if no event is recorded in this thread.
    return static_cast<int32_t>(id);
  }
  return registered->second;
}

#ifdef PADDLE_WITH_CUPTI
namespace {

// Fills runtime_cbid_str with the printable name of each runtime API callback
// id listed below, so RuntimeKind() can show a name instead of a bare number.
// Only the first call does any work.
void initCuptiCbidStr() {
  static bool called = false;
  if (called) return;
  called = true;
// Maps CUPTI_RUNTIME_TRACE_CBID_<cbid> to the string "<cbid>".
#define REGISTER_RUNTIME_CBID_STR(cbid) \
  runtime_cbid_str[CUPTI_RUNTIME_TRACE_CBID_##cbid] = #cbid

  REGISTER_RUNTIME_CBID_STR(cudaBindTexture_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaConfigureCall_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetAttribute_v5000);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetStreamPriorityRange_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceSynchronize_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDriverGetVersion_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventCreateWithFlags_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventDestroy_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventQuery_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventRecord_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFreeHost_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFree_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFuncGetAttributes_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDeviceCount_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDeviceProperties_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDevice_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetErrorString_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetLastError_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaHostAlloc_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaHostGetDevicePointer_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaLaunchKernel_v7000);
  REGISTER_RUNTIME_CBID_STR(cudaMallocHost_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMalloc_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemcpyAsync_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemcpy_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemsetAsync_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemset_v3020);
  REGISTER_RUNTIME_CBID_STR(
      cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_v7000);
  REGISTER_RUNTIME_CBID_STR(cudaPeekAtLastError_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaRuntimeGetVersion_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaSetDevice_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreate_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithFlags_v5000);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithPriority_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaStreamDestroy_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaStreamSynchronize_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamWaitEvent_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaUnbindTexture_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaSetupArgument_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaLaunch_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetPCIBusId_v4010);
#if CUDA_VERSION >= 9000
  REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernel_v9000);
  REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernelMultiDevice_v9000);
#endif

#undef REGISTER_RUNTIME_CBID_STR
}
}  // namespace
#endif  // PADDLE_WITH_CUPTI

743 744
}  // namespace platform
}  // namespace paddle