/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#include <deque>
16
#include <forward_list>
X
Xin Pan 已提交
17
#include <fstream>
18
#include <list>
19
#include <map>
20
#include <mutex>  // NOLINT
21
#include <numeric>
22
#include <sstream>
23 24
#include <string>
#include <thread>  // NOLINT
25 26
#include <unordered_map>
#include <utility>
27 28
#include <vector>

29
#include "glog/logging.h"
30
#include "google/protobuf/text_format.h"
31
#include "paddle/fluid/framework/block_desc.h"
C
chengduo 已提交
32 33
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/profiler.h"
34 35 36 37 38
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace platform {
namespace {
X
Xin Pan 已提交
39 40 41
// Tracking the nested block stacks of each thread.
thread_local std::deque<int> block_id_stack;
// Tracking the nested event stacks.
42
thread_local std::deque<Event *> annotation_stack;
W
wangchaochaohu 已提交
43 44 45
// stack to strore event sunch as pe and so on
static std::deque<Event *> main_thread_annotation_stack{};
static std::deque<std::string> main_thread_annotation_stack_name{};
46 47

std::map<uint32_t, int32_t> system_thread_id_map;
48 49 50

std::once_flag tracer_once_flag;
DeviceTracer *tracer = nullptr;
51 52 53 54 55

void PrintCuptiHint() {
  static bool showed = false;
  if (showed) return;
  showed = true;
T
tianshuo78520a 已提交
56
  LOG(WARNING) << "Invalid timestamp occurred. Please try increasing the "
57 58 59
                  "FLAGS_multiple_of_cupti_buffer_size.";
}

60 61 62 63
}  // namespace
#ifdef PADDLE_WITH_CUPTI

namespace {
64 65 66
// The experimental best performance is
// the same size with CUPTI device buffer size(8M)
constexpr uint64_t kBufSize = 1024 * 1024 * 8;
// Alignment, in bytes, applied to every buffer handed to CUPTI.
constexpr uint64_t kAlignSize = 8;
// ALIGN_BUFFER masks with (align - 1), which only rounds correctly when the
// alignment is a power of two.
static_assert((kAlignSize & (kAlignSize - 1)) == 0,
              "kAlignSize must be a power of two");
68 69
// Lookup tables from a CUPTI callback id to a printable API name. RuntimeKind()
// reads runtime_cbid_str and DriverKind() reads driver_cbid_str; both fall back
// to a generic "<domain> API <id>" label when an id is not present.
std::unordered_map<CUpti_CallbackId, std::string> runtime_cbid_str,
    driver_cbid_str;
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87

// Rounds the pointer `buffer` up to the next multiple of `align` and yields it
// unchanged when it is already aligned. The remainder is computed with the
// mask (align - 1), so `align` has to be a power of two. Both arguments are
// evaluated more than once; do not pass expressions with side effects.
#define ALIGN_BUFFER(buffer, align)                                 \
  (((uintptr_t)(buffer) & ((align)-1))                              \
       ? ((buffer) + (align) - ((uintptr_t)(buffer) & ((align)-1))) \
       : (buffer))

// Evaluates `call` once and checks the resulting CUptiResult. On anything
// other than CUPTI_SUCCESS it prints the file, line, the stringified call
// expression and the CUPTI error string to stderr, then terminates the
// process with exit(-1). Wrapped in do { } while (0) so it behaves as a single
// statement.
#define CUPTI_CALL(call)                                                   \
  do {                                                                     \
    CUptiResult _status = call;                                            \
    if (_status != CUPTI_SUCCESS) {                                        \
      const char *errstr;                                                  \
      dynload::cuptiGetResultString(_status, &errstr);                     \
      fprintf(stderr, "%s:%d: error: function %s failed with error %s.\n", \
              __FILE__, __LINE__, #call, errstr);                          \
      exit(-1);                                                            \
    }                                                                      \
  } while (0)

X
Xin Pan 已提交
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
// Translates a CUPTI memcpy kind into the label used for the timeline event.
// Kinds without a dedicated label are reported as plain "MEMCPY".
std::string MemcpyKind(CUpti_ActivityMemcpyKind kind) {
  const char *label = "MEMCPY";
  switch (kind) {
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD:
      label = "MEMCPY_HtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH:
      label = "MEMCPY_DtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA:
      label = "MEMCPY_HtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH:
      label = "MEMCPY_AtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA:
      label = "MEMCPY_AtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD:
      label = "MEMCPY_AtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA:
      label = "MEMCPY_DtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD:
      label = "MEMCPY_DtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH:
      label = "MEMCPY_HtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP:
      label = "MEMCPY_PtoP";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_FORCE_INT:
      label = "MEMCPY_FORCE_INT";
      break;
    default:
      break;
  }
  return label;
}

118 119 120 121 122 123 124 125 126 127 128 129 130 131
// Returns the registered name of a driver API callback id, or
// "Driver API <id>" when the id has no entry in driver_cbid_str.
std::string DriverKind(CUpti_CallbackId cbid) {
  const auto entry = driver_cbid_str.find(cbid);
  if (entry != driver_cbid_str.end()) {
    return entry->second;
  }
  return "Driver API " + std::to_string(cbid);
}

// Returns the registered name of a runtime API callback id, or
// "Runtime API <id>" when the id has no entry in runtime_cbid_str.
std::string RuntimeKind(CUpti_CallbackId cbid) {
  const auto entry = runtime_cbid_str.find(cbid);
  if (entry != runtime_cbid_str.end()) {
    return entry->second;
  }
  return "Runtime API " + std::to_string(cbid);
}

132 133 134 135
void EnableActivity() {
  // Device activity record is created when CUDA initializes, so we
  // want to enable it before cuInit() or any CUDA runtime call.
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY));
136 137 138 139 140
  CUPTI_CALL(
      dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
141
  // We don't track these activities for now.
D
Dun 已提交
142
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET));
143 144
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE));
145 146 147 148 149 150 151 152 153
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONTEXT));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MARKER));
}

void DisableActivity() {
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY));
154 155 156
  CUPTI_CALL(
      dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DEVICE));
157
  // Disable all other activity record kinds.
158
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONTEXT));
159 160
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME));
D
Dun 已提交
161
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET));
162 163 164
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MARKER));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD));
165 166 167 168
}

// Buffer-request callback registered with CUPTI. Allocates kBufSize bytes plus
// kAlignSize bytes of slack with malloc and hands back the pointer rounded up
// to kAlignSize.
//
// buffer         [out] start of the buffer CUPTI may fill.
// size           [out] usable size of that buffer in bytes (kBufSize).
// maxNumRecords  [out] always set to 0 here.
//
// NOTE(review): bufferCompleted() passes the aligned pointer to free(); that
// is only valid while malloc's result is already kAlignSize-aligned, so that
// ALIGN_BUFFER leaves it unchanged -- confirm for the target platforms.
void CUPTIAPI bufferRequested(uint8_t **buffer, size_t *size,
                              size_t *maxNumRecords) {
  auto *raw = reinterpret_cast<uint8_t *>(malloc(kBufSize + kAlignSize));
  *buffer = ALIGN_BUFFER(raw, kAlignSize);
  *size = kBufSize;
  *maxNumRecords = 0;
}

// Buffer-completed callback registered with CUPTI. Walks every activity record
// in `buffer`, forwards the kinds this tracer understands to the global
// `tracer`, reports dropped records, and finally frees the buffer.
//
// ctx, streamId  passed through to cuptiActivityGetNumDroppedRecords().
// buffer         buffer previously returned by bufferRequested().
// size           total buffer size (unused).
// validSize      number of bytes of `buffer` that hold records.
//
// The function enforces that it is always invoked from the same thread; the
// unlocked pushes in AddKernelRecords()/AddMemRecords() rely on that.
void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
                              size_t size, size_t validSize) {
  // A default-constructed std::thread::id identifies no thread, which makes it
  // a portable "not yet recorded" marker (constructing an id from an integer
  // is not part of the standard interface).
  static std::thread::id cupti_thread_id;
  if (cupti_thread_id == std::thread::id())
    cupti_thread_id = std::this_thread::get_id();
  PADDLE_ENFORCE_EQ(std::this_thread::get_id(), cupti_thread_id,
                    "Only one thread is allowed to call bufferCompleted()");
  CUptiResult status;
  CUpti_Activity *record = nullptr;
  if (validSize > 0) {
    do {
      status = dynload::cuptiActivityGetNextRecord(buffer, validSize, &record);
      if (status == CUPTI_SUCCESS) {
        switch (record->kind) {
          case CUPTI_ACTIVITY_KIND_KERNEL:
          case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
            // The kernel record layout depends on the CUDA version.
#if CUDA_VERSION >= 9000
            auto *kernel =
                reinterpret_cast<const CUpti_ActivityKernel4 *>(record);
#else
            auto *kernel =
                reinterpret_cast<const CUpti_ActivityKernel3 *>(record);
#endif
            tracer->AddKernelRecords(kernel->name, kernel->start, kernel->end,
                                     kernel->deviceId, kernel->streamId,
                                     kernel->correlationId);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMCPY: {
            auto *memcpy =
                reinterpret_cast<const CUpti_ActivityMemcpy *>(record);
            tracer->AddMemRecords(
                MemcpyKind(
                    static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
                memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
                memcpy->correlationId, memcpy->bytes);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMCPY2: {
            auto *memcpy =
                reinterpret_cast<const CUpti_ActivityMemcpy2 *>(record);
            tracer->AddMemRecords(
                MemcpyKind(
                    static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
                memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
                memcpy->correlationId, memcpy->bytes);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMSET: {
            // Memsets are recorded as kernel records named "MEMSET".
            auto *memset =
                reinterpret_cast<const CUpti_ActivityMemset *>(record);
            tracer->AddKernelRecords("MEMSET", memset->start, memset->end,
                                     memset->deviceId, memset->streamId,
                                     memset->correlationId);
            break;
          }
          case CUPTI_ACTIVITY_KIND_DRIVER: {
            auto *api = reinterpret_cast<const CUpti_ActivityAPI *>(record);
            if (api->start != 0 && api->end != 0) {
              // -1 device id represents ActiveKind api call
              tracer->AddActiveKindRecords(
                  DriverKind(api->cbid), api->start, api->end, -1,
                  GetThreadIdFromSystemThreadId(api->threadId),
                  api->correlationId);
            }
            break;
          }
          case CUPTI_ACTIVITY_KIND_RUNTIME: {
            auto *api = reinterpret_cast<const CUpti_ActivityAPI *>(record);
            if (api->start != 0 && api->end != 0) {
              // -1 device id represents ActiveKind api call
              tracer->AddActiveKindRecords(
                  RuntimeKind(api->cbid), api->start, api->end, -1,
                  GetThreadIdFromSystemThreadId(api->threadId),
                  api->correlationId);
            }
            break;
          }
          default: { break; }
        }
      } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
        // Seems not an error in this case.
        break;
      } else {
        // Any other status is fatal: CUPTI_CALL reports it and exits.
        CUPTI_CALL(status);
      }
    } while (1);

    size_t dropped;
    CUPTI_CALL(
        dynload::cuptiActivityGetNumDroppedRecords(ctx, streamId, &dropped));
    if (dropped != 0) {
      fprintf(stderr, "Dropped %u activity records\n",
              static_cast<unsigned int>(dropped));
      PrintCuptiHint();
    }
  }
  free(buffer);
}
273 274 275

void initCuptiCbidStr();

276 277
}  // namespace

Q
qiaolongfei 已提交
278 279
#endif  // PADDLE_WITH_CUPTI

280 281
class DeviceTracerImpl : public DeviceTracer {
 public:
282 283 284 285 286
  DeviceTracerImpl() : enabled_(false) {
#ifdef PADDLE_WITH_CUPTI
    initCuptiCbidStr();
#endif
  }
287

288 289 290 291 292 293 294 295 296
  void AddAnnotation(uint32_t id, Event *event) {
    thread_local std::forward_list<std::pair<uint32_t, Event *>>
        *local_correlations_pairs = nullptr;
    if (local_correlations_pairs == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      correlations_pairs.emplace_front();
      local_correlations_pairs = &correlations_pairs.front();
    }
    local_correlations_pairs->push_front(std::make_pair(id, event));
297 298
  }

X
Xin Pan 已提交
299 300 301
  void AddCPURecords(const std::string &anno, uint64_t start_ns,
                     uint64_t end_ns, int64_t device_id, int64_t thread_id) {
    if (anno.empty()) {
M
minqiyang 已提交
302
      VLOG(1) << "Empty timeline annotation.";
303 304
      return;
    }
305 306 307 308 309 310 311
    thread_local std::forward_list<CPURecord> *local_cpu_records_ = nullptr;
    if (local_cpu_records_ == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      cpu_records_.emplace_front();
      local_cpu_records_ = &cpu_records_.front();
    }
    local_cpu_records_->push_front(
X
Xin Pan 已提交
312
        CPURecord{anno, start_ns, end_ns, device_id, thread_id});
X
Xin Pan 已提交
313 314
  }

X
Xin Pan 已提交
315
  void AddMemRecords(const std::string &name, uint64_t start_ns,
X
Xin Pan 已提交
316
                     uint64_t end_ns, int64_t device_id, int64_t stream_id,
X
Xin Pan 已提交
317
                     uint32_t correlation_id, uint64_t bytes) {
X
Xin Pan 已提交
318
    // 0 means timestamp information could not be collected for the kernel.
319
    if (start_ns == 0 || end_ns == 0 || start_ns == end_ns) {
M
minqiyang 已提交
320
      VLOG(3) << name << " cannot be traced";
321
      PrintCuptiHint();
X
Xin Pan 已提交
322 323
      return;
    }
324 325 326
    // NOTE(liangdun): lock is not needed, only one thread call this function.
    mem_records_.push_front(MemRecord{name, start_ns, end_ns, device_id,
                                      stream_id, correlation_id, bytes});
X
Xin Pan 已提交
327 328
  }

C
chengduo 已提交
329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
  void AddMemInfoRecord(uint64_t start_ns, uint64_t end_ns, size_t bytes,
                        const Place &place, const std::string &alloc_in,
                        const std::string &free_in, int64_t thread_id) {
    if (0 == start_ns || 0 == end_ns) {
      VLOG(3) << alloc_in << ", " << free_in << " Cannot be traced.";
      return;
    }
    thread_local std::forward_list<MemInfoRecord> *local_mem_info_record =
        nullptr;
    if (local_mem_info_record == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      mem_info_record_.emplace_front();
      local_mem_info_record = &mem_info_record_.front();
    }
    local_mem_info_record->emplace_front(MemInfoRecord{
        start_ns, end_ns, bytes, place, thread_id, alloc_in, free_in});
  }

347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
  void AddActiveKindRecords(const std::string &anno, uint64_t start_ns,
                            uint64_t end_ns, int64_t device_id,
                            int64_t thread_id, uint32_t correlation_id) {
    if (anno.empty()) {
      VLOG(1) << "Empty timeline annotation.";
      return;
    }
    thread_local std::forward_list<ActiveKindRecord>
        *local_active_kind_records = nullptr;
    if (local_active_kind_records == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      active_kind_records_.emplace_front();
      local_active_kind_records = &active_kind_records_.front();
    }
    //  lock is not needed, only one thread call this function.
    local_active_kind_records->push_front(ActiveKindRecord{
        anno, start_ns, end_ns, device_id, thread_id, correlation_id});
  }

Z
ZongwuYang 已提交
366 367 368
  void AddKernelRecords(std::string name, uint64_t start, uint64_t end,
                        int64_t device_id, int64_t stream_id,
                        uint32_t correlation_id) {
X
Xin Pan 已提交
369
    // 0 means timestamp information could not be collected for the kernel.
370
    if (start == 0 || end == 0 || start == end) {
M
minqiyang 已提交
371
      VLOG(3) << correlation_id << " cannot be traced";
372
      PrintCuptiHint();
X
Xin Pan 已提交
373 374
      return;
    }
375 376
    // NOTE(liangdun): lock is not needed, only one thread call this function.
    kernel_records_.push_front(
Z
ZongwuYang 已提交
377
        KernelRecord{name, start, end, device_id, stream_id, correlation_id});
378 379 380 381 382 383 384 385 386 387 388 389
  }

  bool IsEnabled() {
    std::lock_guard<std::mutex> l(trace_mu_);
    return enabled_;
  }

  void Enable() {
    std::lock_guard<std::mutex> l(trace_mu_);
    if (enabled_) {
      return;
    }
Q
qiaolongfei 已提交
390 391

#ifdef PADDLE_WITH_CUPTI
392 393 394 395 396 397 398 399 400 401 402 403 404 405
    EnableActivity();

    // Register callbacks for buffer requests and completed by CUPTI.
    CUPTI_CALL(dynload::cuptiActivityRegisterCallbacks(bufferRequested,
                                                       bufferCompleted));

    CUptiResult ret;
    ret = dynload::cuptiSubscribe(
        &subscriber_, static_cast<CUpti_CallbackFunc>(ApiCallback), this);
    if (ret == CUPTI_ERROR_MAX_LIMIT_REACHED) {
      fprintf(stderr, "CUPTI subcriber limit reached.\n");
    } else if (ret != CUPTI_SUCCESS) {
      fprintf(stderr, "Failed to create CUPTI subscriber.\n");
    }
406
    const std::vector<int> runtime_cbids {
407
      CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020,
408
          CUPTI_RUNTIME_TRACE_CBID_cudaSetupArgument_v3020,
409
          CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020,
D
Dun 已提交
410 411
          CUPTI_RUNTIME_TRACE_CBID_cudaMemset_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaMemsetAsync_v3020,
412 413 414 415 416 417 418 419
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000
#if CUDA_VERSION >= 9000
          ,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000
#endif
    };
420 421 422 423
    const std::vector<int> driver_cbids{CUPTI_DRIVER_TRACE_CBID_cuLaunch,
                                        CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid,
                                        CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel};
    for (auto cbid : runtime_cbids)
424 425
      CUPTI_CALL(dynload::cuptiEnableCallback(
          1, subscriber_, CUPTI_CB_DOMAIN_RUNTIME_API, cbid));
426 427 428
    for (auto cbid : driver_cbids)
      CUPTI_CALL(dynload::cuptiEnableCallback(
          1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid));
429
    CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_));
Q
qiaolongfei 已提交
430
#endif  // PADDLE_WITH_CUPTI
431 432 433
    enabled_ = true;
  }

434 435 436 437 438 439 440 441 442 443 444
  void Reset() {
#ifdef PADDLE_WITH_CUPTI
    CUPTI_CALL(
        dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
#endif
    std::lock_guard<std::mutex> l(trace_mu_);
    kernel_records_.clear();
    mem_records_.clear();
    correlations_.clear();
    for (auto &tmp : correlations_pairs) tmp.clear();
    for (auto &tmp : cpu_records_) tmp.clear();
C
chengduo 已提交
445
    for (auto &tmp : mem_info_record_) tmp.clear();
446
    for (auto &tmp : active_kind_records_) tmp.clear();
447 448 449 450 451 452 453 454 455 456 457
  }

  void GenEventKernelCudaElapsedTime() {
#ifdef PADDLE_WITH_CUPTI
    if (correlations_.empty())
      for (auto &tmp : correlations_pairs)
        for (auto &pair : tmp) correlations_[pair.first] = pair.second;
    for (const KernelRecord &r : kernel_records_) {
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        Event *e = c->second;
458 459 460 461 462
        Event *parent = e->parent();
        while (parent) {
          parent->AddCudaElapsedTime(r.start_ns, r.end_ns);
          parent = parent->parent();
        }
463 464 465 466 467 468 469
        e->AddCudaElapsedTime(r.start_ns, r.end_ns);
      }
    }
    for (const auto &r : mem_records_) {
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        Event *e = c->second;
470 471 472 473 474
        Event *parent = e->parent();
        while (parent) {
          parent->AddCudaElapsedTime(r.start_ns, r.end_ns);
          parent = parent->parent();
        }
475 476 477 478 479 480
        e->AddCudaElapsedTime(r.start_ns, r.end_ns);
      }
    }
#endif
  }

X
Xin Pan 已提交
481
  proto::Profile GenProfile(const std::string &profile_path) {
482
    int miss = 0, find = 0;
483 484 485 486
    std::lock_guard<std::mutex> l(trace_mu_);
    proto::Profile profile_pb;
    profile_pb.set_start_ns(start_ns_);
    profile_pb.set_end_ns(end_ns_);
C
chengduo 已提交
487 488
    if (correlations_.empty()) {
      for (auto &tmp : correlations_pairs) {
489
        for (auto &pair : tmp) correlations_[pair.first] = pair.second;
C
chengduo 已提交
490 491 492
      }
    }

493 494
    for (const KernelRecord &r : kernel_records_) {
      auto *event = profile_pb.add_events();
X
Xin Pan 已提交
495
      event->set_type(proto::Event::GPUKernel);
496 497 498 499 500
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        event->set_name(c->second->name());
        event->set_detail_info(r.name);
        find++;
Z
ZongwuYang 已提交
501
      } else {
502 503
        VLOG(10) << "Missing Kernel Event: " + r.name;
        miss++;
Z
ZongwuYang 已提交
504 505
        event->set_name(r.name);
      }
506 507
      event->set_start_ns(r.start_ns);
      event->set_end_ns(r.end_ns);
X
Xin Pan 已提交
508
      event->set_sub_device_id(r.stream_id);
509
      event->set_device_id(r.device_id);
X
Xin Pan 已提交
510
    }
511
    VLOG(1) << "KernelRecord event miss: " << miss << " find: " << find;
C
chengduo 已提交
512

513
    for (auto &tmp : cpu_records_) {
514 515 516 517 518 519 520 521 522
      for (const CPURecord &r : tmp) {
        auto *event = profile_pb.add_events();
        event->set_type(proto::Event::CPU);
        event->set_name(r.name);
        event->set_start_ns(r.start_ns);
        event->set_end_ns(r.end_ns);
        event->set_sub_device_id(r.thread_id);
        event->set_device_id(r.device_id);
      }
523
    }
C
chengduo 已提交
524

525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541
    for (auto &tmp : active_kind_records_) {
      for (const ActiveKindRecord &r : tmp) {
        auto *event = profile_pb.add_events();
        event->set_type(proto::Event::CPU);
        auto c = correlations_.find(r.correlation_id);
        if (c != correlations_.end() && c->second != nullptr) {
          event->set_name(c->second->name());
          event->set_detail_info(r.name);
        } else {
          event->set_name(r.name);
        }
        event->set_start_ns(r.start_ns);
        event->set_end_ns(r.end_ns);
        event->set_sub_device_id(r.thread_id);
        event->set_device_id(r.device_id);
      }
    }
542
    miss = find = 0;
X
Xin Pan 已提交
543 544
    for (const MemRecord &r : mem_records_) {
      auto *event = profile_pb.add_events();
X
Xin Pan 已提交
545
      event->set_type(proto::Event::GPUKernel);
546 547 548 549 550 551 552 553 554
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        event->set_name(c->second->name());
        event->set_detail_info(r.name);
        find++;
      } else {
        miss++;
        event->set_name(r.name);
      }
X
Xin Pan 已提交
555 556
      event->set_start_ns(r.start_ns);
      event->set_end_ns(r.end_ns);
X
Xin Pan 已提交
557
      event->set_sub_device_id(r.stream_id);
X
Xin Pan 已提交
558 559 560
      event->set_device_id(r.device_id);
      event->mutable_memcopy()->set_bytes(r.bytes);
    }
561
    VLOG(1) << "MemRecord event miss: " << miss << " find: " << find;
C
chengduo 已提交
562 563 564 565 566 567 568 569 570 571

    for (auto &tmp : mem_info_record_) {
      for (const auto &r : tmp) {
        auto *event = profile_pb.add_mem_events();
        event->set_device_id(0);
        if (platform::is_cpu_place(r.place)) {
          event->set_place(proto::MemEvent::CPUPlace);
        } else if (platform::is_gpu_place(r.place)) {
          event->set_place(proto::MemEvent::CUDAPlace);
          event->set_device_id(
572
              BOOST_GET_CONST(platform::CUDAPlace, r.place).GetDeviceId());
C
chengduo 已提交
573 574 575 576 577 578 579 580 581 582 583 584 585 586
        } else if (platform::is_cuda_pinned_place(r.place)) {
          event->set_place(proto::MemEvent::CUDAPinnedPlace);
        } else {
          PADDLE_THROW("The current place is not supported.");
        }
        event->set_alloc_in(r.alloc_in);
        event->set_free_in(r.free_in);
        event->set_start_ns(r.start_ns);
        event->set_end_ns(r.end_ns);
        event->set_bytes(r.bytes);
        event->set_thread_id(r.thread_id);
      }
    }

X
Xin Pan 已提交
587
    std::ofstream profile_f;
588 589 590
    profile_f.open(profile_path,
                   std::ios::out | std::ios::trunc | std::ios::binary);
    profile_pb.SerializeToOstream(&profile_f);
X
Xin Pan 已提交
591
    profile_f.close();
592 593 594 595
    return profile_pb;
  }

  void Disable() {
Q
qiaolongfei 已提交
596
#ifdef PADDLE_WITH_CUPTI
597
    // flush might cause additional calls to DeviceTracker.
598 599
    CUPTI_CALL(
        dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
Q
qiaolongfei 已提交
600 601 602
#endif  // PADDLE_WITH_CUPTI
    std::lock_guard<std::mutex> l(trace_mu_);
#ifdef PADDLE_WITH_CUPTI
603
    DisableActivity();
604
    CUPTI_CALL(dynload::cuptiUnsubscribe(subscriber_));
605
    CUPTI_CALL(dynload::cuptiGetTimestamp(&end_ns_));
Q
qiaolongfei 已提交
606
#endif  // PADDLE_WITH_CUPTI
607 608 609 610
    enabled_ = false;
  }

 private:
Q
qiaolongfei 已提交
611
#ifdef PADDLE_WITH_CUPTI
612 613 614
  static void CUPTIAPI ApiCallback(void *userdata, CUpti_CallbackDomain domain,
                                   CUpti_CallbackId cbid, const void *cbdata) {
    auto *cbInfo = reinterpret_cast<const CUpti_CallbackData *>(cbdata);
615 616 617 618
    DeviceTracerImpl *tracer = reinterpret_cast<DeviceTracerImpl *>(userdata);
    if (cbInfo->callbackSite == CUPTI_API_ENTER) {
      Event *event = CurAnnotation();
      tracer->AddAnnotation(cbInfo->correlationId, event);
619 620 621 622
    }
  }
  CUpti_SubscriberHandle subscriber_;
#endif  // PADDLE_WITH_CUPTI
Q
qiaolongfei 已提交
623 624 625 626
  std::mutex trace_mu_;
  bool enabled_;
  uint64_t start_ns_;
  uint64_t end_ns_;
627 628 629
  std::forward_list<KernelRecord> kernel_records_;
  std::forward_list<MemRecord> mem_records_;
  std::forward_list<std::forward_list<CPURecord>> cpu_records_;
C
chengduo 已提交
630
  std::forward_list<std::forward_list<MemInfoRecord>> mem_info_record_;
631
  std::forward_list<std::forward_list<ActiveKindRecord>> active_kind_records_;
632 633 634
  std::forward_list<std::forward_list<std::pair<uint32_t, Event *>>>
      correlations_pairs;
  std::unordered_map<uint32_t, Event *> correlations_;
635 636
};

Q
qiaolongfei 已提交
637
// Allocates a DeviceTracerImpl and stores it in *t. No matching delete is
// visible in this file.
void CreateTracer(DeviceTracer **t) {
  DeviceTracer *instance = new DeviceTracerImpl();
  *t = instance;
}
638 639 640 641 642 643

// Returns the process-wide tracer, creating it on the first call.
// std::call_once makes sure the creation runs exactly once.
DeviceTracer *GetDeviceTracer() {
  std::call_once(tracer_once_flag, [] { CreateTracer(&tracer); });
  return tracer;
}

644 645 646 647 648
// Makes `event` the current annotation of the calling thread.
//
// In order to record PE time, main_thread_annotation_stack keeps the special
// (EventRole::kSpecial) events. Every event started while such an event is
// open is treated as its child, so an event that starts on a different thread
// with an empty local stack is parented to the top special event (the PE::run
// event).
//
// The event's name is prefixed with "<parent name>/" whenever a parent is set.
void SetCurAnnotation(Event *event) {
  if (!annotation_stack.empty()) {
    // Nested event on this thread: the parent is the enclosing event.
    Event *enclosing = annotation_stack.back();
    event->set_parent(enclosing);
    event->set_name(enclosing->name() + "/" + event->name());
  } else if (!main_thread_annotation_stack.empty() &&
             main_thread_annotation_stack.back()->thread_id() !=
                 event->thread_id()) {
    // First event on another thread: adopt the top special event as parent.
    Event *special = main_thread_annotation_stack.back();
    event->set_parent(special);
    event->set_name(special->name() + "/" + event->name());
  }
  annotation_stack.push_back(event);

  if (event->role() == EventRole::kSpecial) {
    // Track special events, together with their '/'-joined name, in the
    // stacks shared by all threads.
    std::string name = event->name();
    if (!main_thread_annotation_stack_name.empty()) {
      name = main_thread_annotation_stack_name.back() + "/" + event->name();
    }
    main_thread_annotation_stack_name.push_back(name);
    main_thread_annotation_stack.push_back(event);
  }
}
X
Xin Pan 已提交
671

W
wangchaochaohu 已提交
672 673 674 675
// Closes the calling thread's current annotation. When that annotation is
// also the top entry of the shared special-event stack (matched by name, as
// before), the entry is removed from main_thread_annotation_stack and
// main_thread_annotation_stack_name as well, so no stale Event pointer is
// left behind for later SetCurAnnotation() calls.
//
// Calling this with no open annotation is a no-op.
void ClearCurAnnotation() {
  if (annotation_stack.empty()) return;

  // SetCurAnnotation() pushes to both shared stacks together, so they are
  // popped together here.
  if (!main_thread_annotation_stack.empty() &&
      main_thread_annotation_stack.back()->name() ==
          annotation_stack.back()->name()) {
    main_thread_annotation_stack_name.pop_back();
    main_thread_annotation_stack.pop_back();
  }
  annotation_stack.pop_back();
}
X
Xin Pan 已提交
690

691 692
// Returns the innermost open event of the calling thread, or nullptr when the
// thread has no open event.
Event *CurAnnotation() {
  return annotation_stack.empty() ? nullptr : annotation_stack.back();
}
695
std::string CurAnnotationName() {
C
chengduo 已提交
696
  if (annotation_stack.empty()) return "Unknown";
697 698
  return annotation_stack.back()->name();
}
X
Xin Pan 已提交
699 700 701 702 703 704

// Pushes `block_id` onto the calling thread's block stack.
void SetCurBlock(int block_id) {
  block_id_stack.emplace_back(block_id);
}

// Pops the innermost block id of the calling thread. The stack must not be
// empty; calls are expected to be paired with SetCurBlock().
void ClearCurBlock() {
  block_id_stack.pop_back();
}

// Returns how many blocks are currently open on the calling thread.
int BlockDepth() {
  return static_cast<int>(block_id_stack.size());
}
705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775

// Returns a 32-bit id for the calling thread.
//
// The id is the textual form of std::this_thread::get_id() parsed as a
// decimal number and truncated to 32 bits. The standard does not specify that
// textual form, so when it is not a plain decimal number the id falls back to
// std::hash<std::thread::id>, also truncated. This keeps the function from
// throwing and still gives a stable per-thread value.
uint32_t GetCurSystemThreadId() {
  const std::thread::id self = std::this_thread::get_id();
  std::stringstream ss;
  ss << self;
  unsigned long long numeric_id = 0;  // NOLINT
  // The text counts as numeric only if the extraction succeeds and consumes
  // everything that was written.
  const bool parsed = static_cast<bool>(ss >> numeric_id) && ss.eof();
  if (!parsed) {
    numeric_id = std::hash<std::thread::id>()(self);
  }
  return static_cast<uint32_t>(numeric_id);
}

// Associates the calling thread's system thread id with `id` in
// system_thread_id_map, replacing any previous association.
// NOTE(review): the map is written here without a lock -- confirm callers
// never race with GetThreadIdFromSystemThreadId().
void RecoreCurThreadId(int32_t id) {
  const uint32_t system_tid = GetCurSystemThreadId();
  VLOG(1) << "RecoreCurThreadId: " << system_tid << " -> " << id;
  system_thread_id_map[system_tid] = id;
}

// Translates a system thread id into the id registered for it through
// RecoreCurThreadId(). When nothing was registered for that thread, the system
// id itself is returned, converted to int32_t.
int32_t GetThreadIdFromSystemThreadId(uint32_t id) {
  const auto found = system_thread_id_map.find(id);
  if (found == system_thread_id_map.end()) {
    // return origin id if no event is recorded in this thread.
    return static_cast<int32_t>(id);
  }
  return found->second;
}

#ifdef PADDLE_WITH_CUPTI
namespace {

// Fills runtime_cbid_str with printable names for the CUDA runtime callback
// ids listed below. Only the first call does any work; later calls return
// immediately. driver_cbid_str is not populated here.
void initCuptiCbidStr() {
  static bool called = false;
  if (called) return;
  called = true;
// Maps CUPTI_RUNTIME_TRACE_CBID_<cbid> to the string "<cbid>".
#define REGISTER_RUNTIME_CBID_STR(cbid) \
  runtime_cbid_str[CUPTI_RUNTIME_TRACE_CBID_##cbid] = #cbid

  REGISTER_RUNTIME_CBID_STR(cudaBindTexture_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaConfigureCall_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetAttribute_v5000);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetStreamPriorityRange_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceSynchronize_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDriverGetVersion_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventCreateWithFlags_v3020);
  // NOTE(review): cudaEventDestroy_v3020 used to be registered twice in a
  // row; the redundant second registration was dropped. Confirm whether a
  // different id was meant in its place.
  REGISTER_RUNTIME_CBID_STR(cudaEventDestroy_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventQuery_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventRecord_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFreeHost_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFree_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFuncGetAttributes_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDeviceCount_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDeviceProperties_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDevice_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetErrorString_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetLastError_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaHostAlloc_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaHostGetDevicePointer_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaLaunchKernel_v7000);
  REGISTER_RUNTIME_CBID_STR(cudaMallocHost_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMalloc_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemcpyAsync_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemcpy_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemsetAsync_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemset_v3020);
  REGISTER_RUNTIME_CBID_STR(
      cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_v7000);
  REGISTER_RUNTIME_CBID_STR(cudaPeekAtLastError_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaRuntimeGetVersion_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaSetDevice_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreate_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithFlags_v5000);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithPriority_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaStreamDestroy_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaStreamSynchronize_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamWaitEvent_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaUnbindTexture_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaSetupArgument_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaLaunch_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetPCIBusId_v4010);
#if CUDA_VERSION >= 9000
  REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernel_v9000);
  REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernelMultiDevice_v9000);
#endif

#undef REGISTER_RUNTIME_CBID_STR
}
}  // namespace
#endif  // PADDLE_WITH_CUPTI

789 790
}  // namespace platform
}  // namespace paddle