device_tracer.cc 24.9 KB
Newer Older
X
Xin Pan 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_tracer.h"
15 16

#include <deque>
17
#include <forward_list>
X
Xin Pan 已提交
18
#include <fstream>
19
#include <list>
20
#include <map>
21
#include <mutex>  // NOLINT
22
#include <numeric>
23
#include <sstream>
24 25
#include <string>
#include <thread>  // NOLINT
26 27
#include <unordered_map>
#include <utility>
28 29
#include <vector>

30
#include "glog/logging.h"
31
#include "google/protobuf/text_format.h"
32 33 34 35 36 37
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace platform {
namespace {
X
Xin Pan 已提交
38 39 40
// Tracking the nested block stacks of each thread.
thread_local std::deque<int> block_id_stack;
// Tracking the nested event stacks.
41 42 43
thread_local std::deque<Event *> annotation_stack;

std::map<uint32_t, int32_t> system_thread_id_map;
44 45 46

std::once_flag tracer_once_flag;
DeviceTracer *tracer = nullptr;
47 48 49 50 51 52 53 54 55

// Emits, at most once per process, a warning suggesting a larger CUPTI
// buffer. Subsequent calls are no-ops.
void PrintCuptiHint() {
  static bool showed = false;
  if (!showed) {
    showed = true;
    LOG(WARNING) << "Invalid timestamp occured. Please try increasing the "
                    "FLAGS_multiple_of_cupti_buffer_size.";
  }
}

56 57 58 59
}  // namespace
#ifdef PADDLE_WITH_CUPTI

namespace {
60 61 62
// Size in bytes of each activity buffer handed to CUPTI. The experimental
// best performance is the same size as the CUPTI device buffer size (8M).
constexpr uint64_t kBufSize = 1024 * 1024 * 8;
// Alignment, in bytes, applied to the buffer address via ALIGN_BUFFER.
// Must be a power of two because ALIGN_BUFFER masks with (align - 1).
constexpr uint64_t kAlignSize = 8;
64 65
// Lookup tables from a CUPTI callback id to a printable API name.
// runtime_cbid_str is filled by initCuptiCbidStr() and read by RuntimeKind().
// driver_cbid_str is read by DriverKind(); nothing in the code visible here
// inserts into it, so DriverKind() falls back to its numeric label.
std::unordered_map<CUpti_CallbackId, std::string> runtime_cbid_str,
    driver_cbid_str;
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83

// Rounds the pointer `buffer` up to the next multiple of `align`, or yields
// `buffer` unchanged when it is already aligned. The mask with (align - 1)
// is only meaningful when `align` is a power of two. Both arguments are
// evaluated more than once, so they must be free of side effects.
#define ALIGN_BUFFER(buffer, align)                                 \
  (((uintptr_t)(buffer) & ((align)-1))                              \
       ? ((buffer) + (align) - ((uintptr_t)(buffer) & ((align)-1))) \
       : (buffer))

// Evaluates `call` once and treats any result other than CUPTI_SUCCESS as
// fatal: the CUPTI error string is looked up, a message containing the file,
// line and the stringified `call` expression is written to stderr, and the
// process terminates via exit(-1). The do/while(0) wrapper lets the macro be
// used as a single statement.
#define CUPTI_CALL(call)                                                   \
  do {                                                                     \
    CUptiResult _status = call;                                            \
    if (_status != CUPTI_SUCCESS) {                                        \
      const char *errstr;                                                  \
      dynload::cuptiGetResultString(_status, &errstr);                     \
      fprintf(stderr, "%s:%d: error: function %s failed with error %s.\n", \
              __FILE__, __LINE__, #call, errstr);                          \
      exit(-1);                                                            \
    }                                                                      \
  } while (0)

X
Xin Pan 已提交
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
// Maps a CUPTI memcpy kind to the label used for the event in the profile.
// Kinds without a dedicated case are reported with the generic "MEMCPY".
std::string MemcpyKind(CUpti_ActivityMemcpyKind kind) {
  const char *label = "MEMCPY";
  switch (kind) {
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD:
      label = "MEMCPY_HtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH:
      label = "MEMCPY_DtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA:
      label = "MEMCPY_HtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH:
      label = "MEMCPY_AtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA:
      label = "MEMCPY_AtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD:
      label = "MEMCPY_AtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA:
      label = "MEMCPY_DtoA";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD:
      label = "MEMCPY_DtoD";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH:
      label = "MEMCPY_HtoH";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP:
      label = "MEMCPY_PtoP";
      break;
    case CUPTI_ACTIVITY_MEMCPY_KIND_FORCE_INT:
      label = "MEMCPY_FORCE_INT";
      break;
    default:
      break;
  }
  return label;
}

114 115 116 117 118 119 120 121 122 123 124 125 126 127
// Returns the registered name for a driver API callback id, or
// "Driver API <id>" when the id is not present in driver_cbid_str.
std::string DriverKind(CUpti_CallbackId cbid) {
  const auto found = driver_cbid_str.find(cbid);
  if (found != driver_cbid_str.end()) {
    return found->second;
  }
  return "Driver API " + std::to_string(cbid);
}

// Returns the registered name for a runtime API callback id, or
// "Runtime API <id>" when the id is not present in runtime_cbid_str.
std::string RuntimeKind(CUpti_CallbackId cbid) {
  const auto found = runtime_cbid_str.find(cbid);
  if (found != runtime_cbid_str.end()) {
    return found->second;
  }
  return "Runtime API " + std::to_string(cbid);
}

128 129 130 131
void EnableActivity() {
  // Device activity record is created when CUDA initializes, so we
  // want to enable it before cuInit() or any CUDA runtime call.
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY));
132 133 134 135 136
  CUPTI_CALL(
      dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
137
  // We don't track these activities for now.
D
Dun 已提交
138
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET));
139 140
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE));
141 142 143 144 145 146 147 148 149
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONTEXT));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MARKER));
}

void DisableActivity() {
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY));
150 151 152
  CUPTI_CALL(
      dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DEVICE));
153
  // Disable all other activity record kinds.
154
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONTEXT));
155 156
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME));
D
Dun 已提交
157
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET));
158 159 160
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MARKER));
  // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD));
161 162 163 164
}

// CUPTI buffer-request callback: hands CUPTI a fresh buffer of kBufSize bytes
// for activity records.
//
//   buffer        [out] start of the buffer, or nullptr if allocation failed.
//   size          [out] usable size of the buffer in bytes (0 on failure).
//   maxNumRecords [out] always 0, i.e. no explicit record limit is requested.
//
// The buffer is later released with free() in bufferCompleted(), so the
// pointer returned here has to be the one malloc() produced. kAlignSize extra
// bytes are reserved for ALIGN_BUFFER; NOTE(review): if malloc() ever returned
// an address not already aligned to kAlignSize, the adjusted pointer would no
// longer be valid for free() — confirm malloc's alignment on target platforms.
void CUPTIAPI bufferRequested(uint8_t **buffer, size_t *size,
                              size_t *maxNumRecords) {
  *maxNumRecords = 0;

  uint8_t *buf = static_cast<uint8_t *>(malloc(kBufSize + kAlignSize));
  if (buf == nullptr) {
    // Report "no buffer available" instead of advertising kBufSize bytes
    // behind a null pointer.
    fprintf(stderr, "Failed to allocate CUPTI activity buffer.\n");
    *buffer = nullptr;
    *size = 0;
    return;
  }

  *size = kBufSize;
  *buffer = ALIGN_BUFFER(buf, kAlignSize);
}

void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
                              size_t size, size_t validSize) {
173 174 175 176 177
  static std::thread::id cupti_thread_id(0);
  if (cupti_thread_id == std::thread::id(0))
    cupti_thread_id = std::this_thread::get_id();
  PADDLE_ENFORCE_EQ(std::this_thread::get_id(), cupti_thread_id,
                    "Only one thread is allowed to call bufferCompleted()");
178 179 180 181 182 183 184 185 186 187 188
  CUptiResult status;
  CUpti_Activity *record = NULL;
  if (validSize > 0) {
    do {
      status = dynload::cuptiActivityGetNextRecord(buffer, validSize, &record);
      if (status == CUPTI_SUCCESS) {
        switch (record->kind) {
          case CUPTI_ACTIVITY_KIND_KERNEL:
          case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
            auto *kernel =
                reinterpret_cast<const CUpti_ActivityKernel3 *>(record);
Z
ZongwuYang 已提交
189
            tracer->AddKernelRecords(kernel->name, kernel->start, kernel->end,
190 191 192 193
                                     kernel->deviceId, kernel->streamId,
                                     kernel->correlationId);
            break;
          }
X
Xin Pan 已提交
194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
          case CUPTI_ACTIVITY_KIND_MEMCPY: {
            auto *memcpy =
                reinterpret_cast<const CUpti_ActivityMemcpy *>(record);
            tracer->AddMemRecords(
                MemcpyKind(
                    static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
                memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
                memcpy->correlationId, memcpy->bytes);
            break;
          }
          case CUPTI_ACTIVITY_KIND_MEMCPY2: {
            auto *memcpy =
                reinterpret_cast<const CUpti_ActivityMemcpy2 *>(record);
            tracer->AddMemRecords(
                MemcpyKind(
                    static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
                memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
                memcpy->correlationId, memcpy->bytes);
            break;
          }
D
Dun 已提交
214 215 216 217 218 219 220 221
          case CUPTI_ACTIVITY_KIND_MEMSET: {
            auto *memset =
                reinterpret_cast<const CUpti_ActivityMemset *>(record);
            tracer->AddKernelRecords("MEMSET", memset->start, memset->end,
                                     memset->deviceId, memset->streamId,
                                     memset->correlationId);
            break;
          }
222 223
          case CUPTI_ACTIVITY_KIND_DRIVER: {
            auto *api = reinterpret_cast<const CUpti_ActivityAPI *>(record);
224 225 226
            if (api->start != 0 && api->end != 0) {
              // -1 device id represents ActiveKind api call
              tracer->AddActiveKindRecords(
227
                  DriverKind(api->cbid), api->start, api->end, -1,
228 229 230
                  GetThreadIdFromSystemThreadId(api->threadId),
                  api->correlationId);
            }
231 232 233 234
            break;
          }
          case CUPTI_ACTIVITY_KIND_RUNTIME: {
            auto *api = reinterpret_cast<const CUpti_ActivityAPI *>(record);
235 236 237
            if (api->start != 0 && api->end != 0) {
              // -1 device id represents ActiveKind api call
              tracer->AddActiveKindRecords(
238
                  RuntimeKind(api->cbid), api->start, api->end, -1,
239 240 241
                  GetThreadIdFromSystemThreadId(api->threadId),
                  api->correlationId);
            }
242 243
            break;
          }
244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
          default: { break; }
        }
      } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
        // Seems not an error in this case.
        break;
      } else {
        CUPTI_CALL(status);
      }
    } while (1);

    size_t dropped;
    CUPTI_CALL(
        dynload::cuptiActivityGetNumDroppedRecords(ctx, streamId, &dropped));
    if (dropped != 0) {
      fprintf(stderr, "Dropped %u activity records\n", (unsigned int)dropped);
259
      PrintCuptiHint();
260 261 262 263
    }
  }
  free(buffer);
}
264 265 266

void initCuptiCbidStr();

267 268
}  // namespace

Q
qiaolongfei 已提交
269 270
#endif  // PADDLE_WITH_CUPTI

271 272
class DeviceTracerImpl : public DeviceTracer {
 public:
273 274 275 276 277
  DeviceTracerImpl() : enabled_(false) {
#ifdef PADDLE_WITH_CUPTI
    initCuptiCbidStr();
#endif
  }
278

279 280 281 282 283 284 285 286 287
  void AddAnnotation(uint32_t id, Event *event) {
    thread_local std::forward_list<std::pair<uint32_t, Event *>>
        *local_correlations_pairs = nullptr;
    if (local_correlations_pairs == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      correlations_pairs.emplace_front();
      local_correlations_pairs = &correlations_pairs.front();
    }
    local_correlations_pairs->push_front(std::make_pair(id, event));
288 289
  }

X
Xin Pan 已提交
290 291 292
  void AddCPURecords(const std::string &anno, uint64_t start_ns,
                     uint64_t end_ns, int64_t device_id, int64_t thread_id) {
    if (anno.empty()) {
M
minqiyang 已提交
293
      VLOG(1) << "Empty timeline annotation.";
294 295
      return;
    }
296 297 298 299 300 301 302
    thread_local std::forward_list<CPURecord> *local_cpu_records_ = nullptr;
    if (local_cpu_records_ == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      cpu_records_.emplace_front();
      local_cpu_records_ = &cpu_records_.front();
    }
    local_cpu_records_->push_front(
X
Xin Pan 已提交
303
        CPURecord{anno, start_ns, end_ns, device_id, thread_id});
X
Xin Pan 已提交
304 305
  }

X
Xin Pan 已提交
306
  void AddMemRecords(const std::string &name, uint64_t start_ns,
X
Xin Pan 已提交
307
                     uint64_t end_ns, int64_t device_id, int64_t stream_id,
X
Xin Pan 已提交
308
                     uint32_t correlation_id, uint64_t bytes) {
X
Xin Pan 已提交
309
    // 0 means timestamp information could not be collected for the kernel.
310
    if (start_ns == 0 || end_ns == 0 || start_ns == end_ns) {
M
minqiyang 已提交
311
      VLOG(3) << name << " cannot be traced";
312
      PrintCuptiHint();
X
Xin Pan 已提交
313 314
      return;
    }
315 316 317
    // NOTE(liangdun): lock is not needed, only one thread call this function.
    mem_records_.push_front(MemRecord{name, start_ns, end_ns, device_id,
                                      stream_id, correlation_id, bytes});
X
Xin Pan 已提交
318 319
  }

320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
  void AddActiveKindRecords(const std::string &anno, uint64_t start_ns,
                            uint64_t end_ns, int64_t device_id,
                            int64_t thread_id, uint32_t correlation_id) {
    if (anno.empty()) {
      VLOG(1) << "Empty timeline annotation.";
      return;
    }
    thread_local std::forward_list<ActiveKindRecord>
        *local_active_kind_records = nullptr;
    if (local_active_kind_records == nullptr) {
      std::lock_guard<std::mutex> l(trace_mu_);
      active_kind_records_.emplace_front();
      local_active_kind_records = &active_kind_records_.front();
    }
    //  lock is not needed, only one thread call this function.
    local_active_kind_records->push_front(ActiveKindRecord{
        anno, start_ns, end_ns, device_id, thread_id, correlation_id});
  }

Z
ZongwuYang 已提交
339 340 341
  void AddKernelRecords(std::string name, uint64_t start, uint64_t end,
                        int64_t device_id, int64_t stream_id,
                        uint32_t correlation_id) {
X
Xin Pan 已提交
342
    // 0 means timestamp information could not be collected for the kernel.
343
    if (start == 0 || end == 0 || start == end) {
M
minqiyang 已提交
344
      VLOG(3) << correlation_id << " cannot be traced";
345
      PrintCuptiHint();
X
Xin Pan 已提交
346 347
      return;
    }
348 349
    // NOTE(liangdun): lock is not needed, only one thread call this function.
    kernel_records_.push_front(
Z
ZongwuYang 已提交
350
        KernelRecord{name, start, end, device_id, stream_id, correlation_id});
351 352 353 354 355 356 357 358 359 360 361 362
  }

  bool IsEnabled() {
    std::lock_guard<std::mutex> l(trace_mu_);
    return enabled_;
  }

  void Enable() {
    std::lock_guard<std::mutex> l(trace_mu_);
    if (enabled_) {
      return;
    }
Q
qiaolongfei 已提交
363 364

#ifdef PADDLE_WITH_CUPTI
365 366 367 368 369 370 371 372 373 374 375 376 377 378
    EnableActivity();

    // Register callbacks for buffer requests and completed by CUPTI.
    CUPTI_CALL(dynload::cuptiActivityRegisterCallbacks(bufferRequested,
                                                       bufferCompleted));

    CUptiResult ret;
    ret = dynload::cuptiSubscribe(
        &subscriber_, static_cast<CUpti_CallbackFunc>(ApiCallback), this);
    if (ret == CUPTI_ERROR_MAX_LIMIT_REACHED) {
      fprintf(stderr, "CUPTI subcriber limit reached.\n");
    } else if (ret != CUPTI_SUCCESS) {
      fprintf(stderr, "Failed to create CUPTI subscriber.\n");
    }
379 380
    const std::vector<int> cbids {
      CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020,
381
          CUPTI_RUNTIME_TRACE_CBID_cudaSetupArgument_v3020,
382
          CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020,
D
Dun 已提交
383 384
          CUPTI_RUNTIME_TRACE_CBID_cudaMemset_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaMemsetAsync_v3020,
385 386 387 388 389 390 391 392 393 394 395
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000
#if CUDA_VERSION >= 9000
          ,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000,
          CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000
#endif
    };
    for (auto cbid : cbids)
      CUPTI_CALL(dynload::cuptiEnableCallback(
          1, subscriber_, CUPTI_CB_DOMAIN_RUNTIME_API, cbid));
396
    CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_));
Q
qiaolongfei 已提交
397
#endif  // PADDLE_WITH_CUPTI
398 399 400
    enabled_ = true;
  }

401 402 403 404 405 406 407 408 409 410 411
  void Reset() {
#ifdef PADDLE_WITH_CUPTI
    CUPTI_CALL(
        dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
#endif
    std::lock_guard<std::mutex> l(trace_mu_);
    kernel_records_.clear();
    mem_records_.clear();
    correlations_.clear();
    for (auto &tmp : correlations_pairs) tmp.clear();
    for (auto &tmp : cpu_records_) tmp.clear();
412
    for (auto &tmp : active_kind_records_) tmp.clear();
413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436
  }

  void GenEventKernelCudaElapsedTime() {
#ifdef PADDLE_WITH_CUPTI
    if (correlations_.empty())
      for (auto &tmp : correlations_pairs)
        for (auto &pair : tmp) correlations_[pair.first] = pair.second;
    for (const KernelRecord &r : kernel_records_) {
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        Event *e = c->second;
        e->AddCudaElapsedTime(r.start_ns, r.end_ns);
      }
    }
    for (const auto &r : mem_records_) {
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        Event *e = c->second;
        e->AddCudaElapsedTime(r.start_ns, r.end_ns);
      }
    }
#endif
  }

X
Xin Pan 已提交
437
  proto::Profile GenProfile(const std::string &profile_path) {
438
    int miss = 0, find = 0;
439 440 441 442
    std::lock_guard<std::mutex> l(trace_mu_);
    proto::Profile profile_pb;
    profile_pb.set_start_ns(start_ns_);
    profile_pb.set_end_ns(end_ns_);
443 444 445
    if (correlations_.empty())
      for (auto &tmp : correlations_pairs)
        for (auto &pair : tmp) correlations_[pair.first] = pair.second;
446 447
    for (const KernelRecord &r : kernel_records_) {
      auto *event = profile_pb.add_events();
X
Xin Pan 已提交
448
      event->set_type(proto::Event::GPUKernel);
449 450 451 452 453
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        event->set_name(c->second->name());
        event->set_detail_info(r.name);
        find++;
Z
ZongwuYang 已提交
454
      } else {
455 456
        VLOG(10) << "Missing Kernel Event: " + r.name;
        miss++;
Z
ZongwuYang 已提交
457 458
        event->set_name(r.name);
      }
459 460
      event->set_start_ns(r.start_ns);
      event->set_end_ns(r.end_ns);
X
Xin Pan 已提交
461
      event->set_sub_device_id(r.stream_id);
462
      event->set_device_id(r.device_id);
X
Xin Pan 已提交
463
    }
464
    VLOG(1) << "KernelRecord event miss: " << miss << " find: " << find;
465
    for (auto &tmp : cpu_records_) {
466 467 468 469 470 471 472 473 474
      for (const CPURecord &r : tmp) {
        auto *event = profile_pb.add_events();
        event->set_type(proto::Event::CPU);
        event->set_name(r.name);
        event->set_start_ns(r.start_ns);
        event->set_end_ns(r.end_ns);
        event->set_sub_device_id(r.thread_id);
        event->set_device_id(r.device_id);
      }
475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492
    }
    for (auto &tmp : active_kind_records_) {
      for (const ActiveKindRecord &r : tmp) {
        auto *event = profile_pb.add_events();
        event->set_type(proto::Event::CPU);
        auto c = correlations_.find(r.correlation_id);
        if (c != correlations_.end() && c->second != nullptr) {
          event->set_name(c->second->name());
          event->set_detail_info(r.name);
        } else {
          event->set_name(r.name);
        }
        event->set_start_ns(r.start_ns);
        event->set_end_ns(r.end_ns);
        event->set_sub_device_id(r.thread_id);
        event->set_device_id(r.device_id);
      }
    }
493
    miss = find = 0;
X
Xin Pan 已提交
494 495
    for (const MemRecord &r : mem_records_) {
      auto *event = profile_pb.add_events();
X
Xin Pan 已提交
496
      event->set_type(proto::Event::GPUKernel);
497 498 499 500 501 502 503 504 505
      auto c = correlations_.find(r.correlation_id);
      if (c != correlations_.end() && c->second != nullptr) {
        event->set_name(c->second->name());
        event->set_detail_info(r.name);
        find++;
      } else {
        miss++;
        event->set_name(r.name);
      }
X
Xin Pan 已提交
506 507
      event->set_start_ns(r.start_ns);
      event->set_end_ns(r.end_ns);
X
Xin Pan 已提交
508
      event->set_sub_device_id(r.stream_id);
X
Xin Pan 已提交
509 510 511
      event->set_device_id(r.device_id);
      event->mutable_memcopy()->set_bytes(r.bytes);
    }
512
    VLOG(1) << "MemRecord event miss: " << miss << " find: " << find;
X
Xin Pan 已提交
513
    std::ofstream profile_f;
514 515 516
    profile_f.open(profile_path,
                   std::ios::out | std::ios::trunc | std::ios::binary);
    profile_pb.SerializeToOstream(&profile_f);
X
Xin Pan 已提交
517
    profile_f.close();
518 519 520 521
    return profile_pb;
  }

  void Disable() {
Q
qiaolongfei 已提交
522
#ifdef PADDLE_WITH_CUPTI
523
    // flush might cause additional calls to DeviceTracker.
524 525
    CUPTI_CALL(
        dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
Q
qiaolongfei 已提交
526 527 528
#endif  // PADDLE_WITH_CUPTI
    std::lock_guard<std::mutex> l(trace_mu_);
#ifdef PADDLE_WITH_CUPTI
529
    DisableActivity();
530
    CUPTI_CALL(dynload::cuptiUnsubscribe(subscriber_));
531
    CUPTI_CALL(dynload::cuptiGetTimestamp(&end_ns_));
Q
qiaolongfei 已提交
532
#endif  // PADDLE_WITH_CUPTI
533 534 535 536
    enabled_ = false;
  }

 private:
Q
qiaolongfei 已提交
537
#ifdef PADDLE_WITH_CUPTI
538 539 540
  static void CUPTIAPI ApiCallback(void *userdata, CUpti_CallbackDomain domain,
                                   CUpti_CallbackId cbid, const void *cbdata) {
    auto *cbInfo = reinterpret_cast<const CUpti_CallbackData *>(cbdata);
541 542 543 544
    DeviceTracerImpl *tracer = reinterpret_cast<DeviceTracerImpl *>(userdata);
    if (cbInfo->callbackSite == CUPTI_API_ENTER) {
      Event *event = CurAnnotation();
      tracer->AddAnnotation(cbInfo->correlationId, event);
545 546 547 548
    }
  }
  CUpti_SubscriberHandle subscriber_;
#endif  // PADDLE_WITH_CUPTI
Q
qiaolongfei 已提交
549 550 551 552
  std::mutex trace_mu_;
  bool enabled_;
  uint64_t start_ns_;
  uint64_t end_ns_;
553 554 555
  std::forward_list<KernelRecord> kernel_records_;
  std::forward_list<MemRecord> mem_records_;
  std::forward_list<std::forward_list<CPURecord>> cpu_records_;
556
  std::forward_list<std::forward_list<ActiveKindRecord>> active_kind_records_;
557 558 559
  std::forward_list<std::forward_list<std::pair<uint32_t, Event *>>>
      correlations_pairs;
  std::unordered_map<uint32_t, Event *> correlations_;
560 561
};

Q
qiaolongfei 已提交
562
// Allocates a new DeviceTracerImpl and stores it in `*t`. The caller takes
// over the raw pointer; nothing in this function releases it.
void CreateTracer(DeviceTracer **t) {
  DeviceTracer *instance = new DeviceTracerImpl();
  *t = instance;
}
563 564 565 566 567 568

// Returns the process-wide tracer, creating it on the first call.
// std::call_once makes sure CreateTracer runs exactly once.
DeviceTracer *GetDeviceTracer() {
  std::call_once(tracer_once_flag, [] { CreateTracer(&tracer); });
  return tracer;
}

569
// Pushes `event` onto the calling thread's annotation stack, making it the
// current annotation.
void SetCurAnnotation(Event *event) {
  annotation_stack.emplace_back(event);
}
X
Xin Pan 已提交
570 571 572

// Pops the calling thread's current annotation. An unmatched call (empty
// stack) is ignored, because pop_back() on an empty deque is undefined
// behavior.
void ClearCurAnnotation() {
  if (annotation_stack.empty()) return;
  annotation_stack.pop_back();
}

573 574
// Returns the innermost annotation of the calling thread, or nullptr when no
// annotation is active.
Event *CurAnnotation() {
  return annotation_stack.empty() ? nullptr : annotation_stack.back();
}
577 578 579 580
// Returns the name of the calling thread's innermost annotation, or an empty
// string when no annotation is active.
std::string CurAnnotationName() {
  if (!annotation_stack.empty()) {
    return annotation_stack.back()->name();
  }
  return "";
}
X
Xin Pan 已提交
581 582 583 584 585 586

// Pushes `block_id` onto the calling thread's block stack.
void SetCurBlock(int block_id) {
  block_id_stack.emplace_back(block_id);
}

// Pops the calling thread's innermost block id. An unmatched call (empty
// stack) is ignored, because pop_back() on an empty deque is undefined
// behavior.
void ClearCurBlock() {
  if (block_id_stack.empty()) return;
  block_id_stack.pop_back();
}

// Returns how many blocks are currently nested on the calling thread.
int BlockDepth() {
  const auto depth = block_id_stack.size();
  return static_cast<int>(depth);
}
587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657

// Returns a 32-bit numeric id for the calling thread, obtained by parsing
// the text that std::thread::id streams out and truncating it to uint32_t.
//
// The textual form of std::thread::id is implementation-defined. Base 0 lets
// std::stoull auto-detect the radix, so a plain decimal rendering parses
// exactly as before, while a "0x..." hexadecimal rendering is parsed in full
// instead of stopping at the 'x' and yielding 0 for every thread.
// std::stoull still throws if the text has no leading number at all.
uint32_t GetCurSystemThreadId() {
  std::stringstream ss;
  ss << std::this_thread::get_id();
  return static_cast<uint32_t>(std::stoull(ss.str(), nullptr, 0));
}

// Associates the calling thread's system thread id with the profiler thread
// id `id` in system_thread_id_map, replacing any earlier entry.
void RecoreCurThreadId(int32_t id) {
  const uint32_t system_id = GetCurSystemThreadId();
  VLOG(1) << "RecoreCurThreadId: " << system_id << " -> " << id;
  system_thread_id_map[system_id] = id;
}

// Translates a system thread id into the profiler thread id registered via
// RecoreCurThreadId(). If the thread never registered (no event was recorded
// in it), the system id itself is returned, cast to int32_t.
int32_t GetThreadIdFromSystemThreadId(uint32_t id) {
  const auto entry = system_thread_id_map.find(id);
  if (entry == system_thread_id_map.end()) {
    return static_cast<int32_t>(id);
  }
  return entry->second;
}

#ifdef PADDLE_WITH_CUPTI
namespace {

// Populates runtime_cbid_str with a printable name for each CUDA runtime
// callback id listed below. Only the first call does any work; later calls
// return immediately. The `called` flag is a plain bool, so the guard itself
// is not synchronized.
void initCuptiCbidStr() {
  static bool called = false;
  if (called) return;
  called = true;

// Maps CUPTI_RUNTIME_TRACE_CBID_<cbid> to the string "<cbid>".
#define REGISTER_RUNTIME_CBID_STR(cbid) \
  runtime_cbid_str[CUPTI_RUNTIME_TRACE_CBID_##cbid] = #cbid

  REGISTER_RUNTIME_CBID_STR(cudaBindTexture_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaConfigureCall_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetAttribute_v5000);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetStreamPriorityRange_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceSynchronize_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDriverGetVersion_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventCreateWithFlags_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventDestroy_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventQuery_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaEventRecord_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFreeHost_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFree_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaFuncGetAttributes_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDeviceCount_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDeviceProperties_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetDevice_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetErrorString_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaGetLastError_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaHostAlloc_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaHostGetDevicePointer_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaLaunchKernel_v7000);
  REGISTER_RUNTIME_CBID_STR(cudaMallocHost_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMalloc_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemcpyAsync_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemcpy_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemsetAsync_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaMemset_v3020);
  REGISTER_RUNTIME_CBID_STR(
      cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_v7000);
  REGISTER_RUNTIME_CBID_STR(cudaPeekAtLastError_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaRuntimeGetVersion_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaSetDevice_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreate_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithFlags_v5000);
  REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithPriority_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaStreamDestroy_v5050);
  REGISTER_RUNTIME_CBID_STR(cudaStreamSynchronize_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaStreamWaitEvent_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaUnbindTexture_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaSetupArgument_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaLaunch_v3020);
  REGISTER_RUNTIME_CBID_STR(cudaDeviceGetPCIBusId_v4010);
#if CUDA_VERSION >= 9000
  REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernel_v9000);
  REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernelMultiDevice_v9000);
#endif

#undef REGISTER_RUNTIME_CBID_STR
}
}  // namespace
#endif  // PADDLE_WITH_CUPTI

671 672
}  // namespace platform
}  // namespace paddle