chrometracing_logger.cc 18.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdio>
#include <ctime>
#include <string>

#include "glog/logging.h"

#include "paddle/fluid/platform/device/gpu/gpu_info.h"
C
chenjian 已提交
21
#include "paddle/fluid/platform/enforce.h"
22 23
#include "paddle/fluid/platform/profiler/chrometracing_logger.h"
#include "paddle/fluid/platform/profiler/event_node.h"
C
chenjian 已提交
24
#include "paddle/fluid/platform/profiler/utils.h"
25 26 27 28 29 30

namespace paddle {
namespace platform {

// Value written to the "schemaVersion" field at the top of every trace file.
static const char* const kSchemaVersion = "1.0.0";
// printf-style template for the fallback output file name; the two %s slots
// receive the process id and the formatted local time (see DefaultFileName).
static const char* const kDefaultFilename = "pid_%s_time_%s.paddle_trace.json";
// File-wide counter; StartLog() stamps the current value into "span_indx"
// and post-increments it.
static uint32_t span_indx = 0;

// Builds the trace file name used when the caller supplies an empty filename:
// kDefaultFilename filled with the process id and the local time string.
static std::string DefaultFileName() {
  // kDefaultFilename takes the pid through a %s conversion, so it must be
  // handed a C string. GetProcessId() is assumed to return an integer (process
  // ids are printed with %lld elsewhere in this file); passing that integer
  // straight to %s is undefined behaviour, so convert it to text first.
  const std::string pid = std::to_string(GetProcessId());
  return string_format(std::string(kDefaultFilename), pid.c_str(),
                       GetStringFormatLocalTime().c_str());
}

// Category label written to the "cat" field of each event. It is indexed as
// categary_name_[static_cast<int>(type)], so the entry order has to follow the
// enumerator order of TracerEventType.
const char* ChromeTracingLogger::categary_name_[] = {
    "Operator",           // 0
    "Dataloader",         // 1
    "ProfileStep",        // 2
    "CudaRuntime",        // 3
    "Kernel",             // 4
    "Memcpy",             // 5
    "Memset",             // 6
    "UserDefined",        // 7
    "OperatorInner",      // 8
    "Forward",            // 9
    "Backward",           // 10
    "Optimization",       // 11
    "Communication",      // 12
    "PythonOp",           // 13
    "PythonUserDefined",  // 14
};

void ChromeTracingLogger::OpenFile() {
  output_file_stream_.open(filename_,
                           std::ofstream::out | std::ofstream::trunc);
  if (!output_file_stream_) {
C
chenjian 已提交
49 50
    LOG(WARNING) << "Unable to open file for writing profiling data."
                 << std::endl;
51
  } else {
C
chenjian 已提交
52
    LOG(INFO) << "writing profiling data to " << filename_ << std::endl;
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
  }
}

// Creates a logger writing to `filename`, or to an auto-generated name from
// DefaultFileName() when `filename` is empty, then writes the JSON preamble.
ChromeTracingLogger::ChromeTracingLogger(const std::string& filename) {
  if (filename.empty()) {
    filename_ = DefaultFileName();
  } else {
    filename_ = filename;
  }
  OpenFile();
  StartLog();
}

// C-string convenience constructor. A null pointer is treated like an empty
// name (auto-generated file name) instead of being passed to
// std::string(const char*), which is undefined behaviour for nullptr. All
// remaining work is delegated to the std::string constructor.
ChromeTracingLogger::ChromeTracingLogger(const char* filename_cstr)
    : ChromeTracingLogger(
          std::string(filename_cstr != nullptr ? filename_cstr : "")) {}

// Writes the closing part of the JSON document, then releases the file.
ChromeTracingLogger::~ChromeTracingLogger() {
  this->EndLog();
  this->output_file_stream_.close();
}

void ChromeTracingLogger::LogNodeTrees(const NodeTrees& node_trees) {
  // log all nodes except root node, root node is a helper node.
  const std::map<uint64_t, std::vector<HostTraceEventNode*>>
      thread2host_event_nodes = node_trees.Traverse(true);
  for (auto it = thread2host_event_nodes.begin();
       it != thread2host_event_nodes.end(); ++it) {
    for (auto hostnode = it->second.begin(); hostnode != it->second.end();
         ++hostnode) {
      if (hostnode != it->second.begin()) {  // skip root node
        (*hostnode)->LogMe(this);
      }
      for (auto runtimenode = (*hostnode)->GetRuntimeTraceEventNodes().begin();
           runtimenode != (*hostnode)->GetRuntimeTraceEventNodes().end();
           ++runtimenode) {
        (*runtimenode)->LogMe(this);
        for (auto devicenode =
                 (*runtimenode)->GetDeviceTraceEventNodes().begin();
             devicenode != (*runtimenode)->GetDeviceTraceEventNodes().end();
             ++devicenode) {
          (*devicenode)->LogMe(this);
        }
      }
    }
  }
}

void ChromeTracingLogger::LogHostTraceEventNode(
    const HostTraceEventNode& host_node) {
  if (!output_file_stream_) {
    return;
  }
C
chenjian 已提交
105 106 107 108 109 110 111 112 113 114 115
  switch (host_node.Type()) {
    case TracerEventType::ProfileStep:
    case TracerEventType::Forward:
    case TracerEventType::Backward:
    case TracerEventType::Dataloader:
    case TracerEventType::Optimization:
    case TracerEventType::PythonOp:
    case TracerEventType::PythonUserDefined:
      output_file_stream_ << string_format(
          std::string(
              R"JSON(
116
  { 
C
chenjian 已提交
117
    "name": "%s", "pid": %lld, "tid": "%lld(Python)",
118 119 120
    "ts": %lld, "dur": %lld,
    "ph": "X", "cat": "%s", 
    "args": {
C
chenjian 已提交
121 122
      "start_ns": %lld,
      "end_ns": %lld
123 124 125
    }
  },
  )JSON"),
C
chenjian 已提交
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
          host_node.Name().c_str(), host_node.ProcessId(), host_node.ThreadId(),
          nsToUs(host_node.StartNs()), nsToUs(host_node.Duration()),
          categary_name_[static_cast<int>(host_node.Type())],
          host_node.StartNs(), host_node.EndNs());
      break;
    default:
      output_file_stream_ << string_format(
          std::string(
              R"JSON(
  { 
    "name": "%s", "pid": %lld, "tid": "%lld(C++)",
    "ts": %lld, "dur": %lld,
    "ph": "X", "cat": "%s", 
    "args": {
      "start_ns": %lld,
      "end_ns": %lld
    }
  },
  )JSON"),
          host_node.Name().c_str(), host_node.ProcessId(), host_node.ThreadId(),
          nsToUs(host_node.StartNs()), nsToUs(host_node.Duration()),
          categary_name_[static_cast<int>(host_node.Type())],
          host_node.StartNs(), host_node.EndNs());
      break;
  }

  pid_tid_set_.insert({host_node.ProcessId(), host_node.ThreadId()});
153 154 155 156 157 158 159 160 161 162 163
}

// Writes a runtime API call as an "X" event on the "<tid>(C++)" track, then a
// flow-start ("ph": "s") event named "launch". Its id is the correlation id,
// which LogDeviceTraceEventNode() reuses for the matching flow end, linking
// the call to its device activity.
void ChromeTracingLogger::LogRuntimeTraceEventNode(
    const CudaRuntimeTraceEventNode& runtime_node) {
  if (!output_file_stream_) {
    return;
  }
  output_file_stream_ << string_format(
      std::string(
          R"JSON(
  { 
    "name": "%s", "pid": %lld, "tid": "%lld(C++)",
    "ts": %lld, "dur": %lld,
    "ph": "X", "cat": "%s", 
    "args": {
      "correlation id": %d,
      "start_ns": %lld,
      "end_ns": %lld
    }
  },
  )JSON"),
      runtime_node.Name().c_str(), runtime_node.ProcessId(),
      runtime_node.ThreadId(), nsToUs(runtime_node.StartNs()),
      nsToUs(runtime_node.Duration()),
      categary_name_[static_cast<int>(runtime_node.Type())],
      runtime_node.CorrelationId(), runtime_node.StartNs(),
      runtime_node.EndNs());
  // Both events share the same (pid, tid), so a single registration for
  // RefineDisplayName() is enough.
  pid_tid_set_.insert({runtime_node.ProcessId(), runtime_node.ThreadId()});

  // The flow start is anchored at the midpoint of the runtime call.
  output_file_stream_ << string_format(
      std::string(
          R"JSON(
  { 
    "name": "launch", "id": %d, "pid": %lld, "tid": "%lld(C++)",
    "ts": %lld, 
    "ph": "s", "cat": "async"
  },
  )JSON"),
      runtime_node.CorrelationId(), runtime_node.ProcessId(),
      runtime_node.ThreadId(),
      nsToUs((runtime_node.StartNs() + runtime_node.EndNs()) >> 1));
}

void ChromeTracingLogger::LogDeviceTraceEventNode(
    const DeviceTraceEventNode& device_node) {
  if (!output_file_stream_) {
    return;
  }
  switch (device_node.Type()) {
    case TracerEventType::Kernel:
      HandleTypeKernel(device_node);
      break;
    case TracerEventType::Memcpy:
      HandleTypeMemcpy(device_node);
      break;
    case TracerEventType::Memset:
      HandleTypeMemset(device_node);
    default:
      break;
  }
C
chenjian 已提交
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
  if (nsToUs(device_node.Duration()) == 0) {
    output_file_stream_ << string_format(
        std::string(
            R"JSON(
  { 
    "name": "launch", "id": %d, "pid": %lld, "tid": %lld,
    "ts": %lld, 
    "ph": "f", "cat": "async"
  },
  )JSON"),
        device_node.CorrelationId(), device_node.DeviceId(),
        device_node.StreamId(), nsToUs(device_node.StartNs()));
    deviceid_streamid_set_.insert(
        {device_node.DeviceId(), device_node.StreamId()});
  } else {
    output_file_stream_ << string_format(
        std::string(
            R"JSON(
  { 
    "name": "launch", "id": %d, "pid": %lld, "tid": %lld,
    "ts": %lld, 
    "ph": "f", "cat": "async", "bp": "e"
  },
  )JSON"),
        device_node.CorrelationId(), device_node.DeviceId(),
        device_node.StreamId(),
        nsToUs((device_node.StartNs() + device_node.EndNs()) >> 1));
    deviceid_streamid_set_.insert(
        {device_node.DeviceId(), device_node.StreamId()});
  }
244 245 246 247 248 249 250 251
}

void ChromeTracingLogger::HandleTypeKernel(
    const DeviceTraceEventNode& device_node) {
  KernelEventInfo kernel_info = device_node.KernelInfo();
  float blocks_per_sm = 0.0;
  float warps_per_sm = 0.0;
  float occupancy = 0.0;
C
chenjian 已提交
252
#if defined(PADDLE_WITH_CUPTI)
253 254 255
  constexpr int threads_per_warp = 32;
  const gpuDeviceProp& device_property =
      GetDeviceProperties(device_node.DeviceId());
C
chenjian 已提交
256 257 258
  blocks_per_sm = static_cast<float>(kernel_info.grid_x * kernel_info.grid_y *
                                     kernel_info.grid_z) /
                  device_property.multiProcessorCount;
259 260 261
  warps_per_sm = blocks_per_sm * (kernel_info.block_x * kernel_info.block_y *
                                  kernel_info.block_z) /
                 threads_per_warp;
C
chenjian 已提交
262 263 264 265 266
  occupancy = CalculateEstOccupancy(
      device_node.DeviceId(), kernel_info.registers_per_thread,
      kernel_info.static_shared_memory, kernel_info.dynamic_shared_memory,
      kernel_info.block_x, kernel_info.block_y, kernel_info.block_z,
      blocks_per_sm);
267 268 269 270 271 272 273 274 275 276
#endif

  output_file_stream_ << string_format(
      std::string(
          R"JSON(
  { 
    "name": "%s", "pid": %lld, "tid": %lld,
    "ts": %lld, "dur": %lld,
    "ph": "X", "cat": "%s", 
    "args": {
C
chenjian 已提交
277 278
      "start_ns": %lld,
      "end_ns": %lld,
279 280 281
      "device": %d, "context": %d,
      "stream": %d, "correlation id": %d,
      "registers per thread": %d,
C
chenjian 已提交
282
      "shared memory": %d,
283 284 285 286
      "blocks per SM": %f,
      "warps per SM": %f,
      "grid": [%d, %d, %d],
      "block": [%d, %d, %d],
C
chenjian 已提交
287
      "theoretical achieved occupancy %%": %f
288 289 290 291 292 293 294
    }
  },
  )JSON"),
      device_node.Name().c_str(), device_node.DeviceId(),
      device_node.StreamId(), nsToUs(device_node.StartNs()),
      nsToUs(device_node.Duration()),
      categary_name_[static_cast<int>(device_node.Type())],
C
chenjian 已提交
295 296
      device_node.StartNs(), device_node.EndNs(), device_node.DeviceId(),
      device_node.ContextId(), device_node.StreamId(),
297 298 299 300
      device_node.CorrelationId(), kernel_info.registers_per_thread,
      kernel_info.static_shared_memory + kernel_info.dynamic_shared_memory,
      blocks_per_sm, warps_per_sm, kernel_info.grid_x, kernel_info.grid_y,
      kernel_info.grid_z, kernel_info.block_x, kernel_info.block_y,
C
chenjian 已提交
301
      kernel_info.block_z, occupancy * 100);
302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
}

// Writes a memcpy as an "X" event on its (device, stream) track, with the
// byte count and achieved bandwidth in "args".
void ChromeTracingLogger::HandleTypeMemcpy(
    const DeviceTraceEventNode& device_node) {
  MemcpyEventInfo memcpy_info = device_node.MemcpyInfo();
  // Bytes per nanosecond equals 1e9 bytes per second, i.e. decimal GB/s.
  // A zero duration leaves the bandwidth at 0 instead of dividing by zero.
  float memory_bandwidth = 0;
  if (device_node.Duration() > 0) {
    memory_bandwidth = memcpy_info.num_bytes * 1.0 / device_node.Duration();
  }
  output_file_stream_ << string_format(
      std::string(
          R"JSON(
  {
    "name": "%s", "pid": %lld, "tid": %lld,
    "ts": %lld, "dur": %lld,
    "ph": "X", "cat": "%s", 
    "args": {
      "start_ns": %lld,
      "end_ns": %lld,
      "stream": %d, "correlation id": %d,
      "bytes": %d, "memory bandwidth (GB/s)": %f
    }
  },
  )JSON"),
      device_node.Name().c_str(), device_node.DeviceId(),
      device_node.StreamId(), nsToUs(device_node.StartNs()),
      nsToUs(device_node.Duration()),
      categary_name_[static_cast<int>(device_node.Type())],
      device_node.StartNs(), device_node.EndNs(), device_node.StreamId(),
      device_node.CorrelationId(), memcpy_info.num_bytes, memory_bandwidth);
}

void ChromeTracingLogger::HandleTypeMemset(
    const DeviceTraceEventNode& device_node) {
  MemsetEventInfo memset_info = device_node.MemsetInfo();
  output_file_stream_ << string_format(
      std::string(
          R"JSON(
  {
    "name": "%s", "pid": %lld, "tid": %lld,
    "ts": %lld, "dur": %lld,
    "ph": "X", "cat": "%s", 
    "args": {
C
chenjian 已提交
345 346
      "start_ns": %lld,
      "end_ns": %lld,
347 348 349 350 351 352 353 354 355 356
      "device": %d, "context": %d,
      "stream": %d, "correlation id": %d,
      "bytes": %d, "value": %d
    }
  },
  )JSON"),
      device_node.Name().c_str(), device_node.DeviceId(),
      device_node.StreamId(), nsToUs(device_node.StartNs()),
      nsToUs(device_node.Duration()),
      categary_name_[static_cast<int>(device_node.Type())],
C
chenjian 已提交
357 358
      device_node.StartNs(), device_node.EndNs(), device_node.DeviceId(),
      device_node.ContextId(), device_node.StreamId(),
359 360 361 362 363 364 365 366
      device_node.CorrelationId(), memset_info.num_bytes, memset_info.value);
}

void ChromeTracingLogger::StartLog() {
  output_file_stream_ << string_format(std::string(
                                           R"JSON(
  { 
    "schemaVersion": "%s",
C
chenjian 已提交
367 368
    "displayTimeUnit": "ms",
    "span_indx": "%d",
369
  )JSON"),
C
chenjian 已提交
370
                                       kSchemaVersion, span_indx++);
371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434
// add device property information
#if defined(PADDLE_WITH_CUDA)
  output_file_stream_ << std::string(R"JSON(
    "deviceProperties": [
  )JSON");
  std::vector<int> device_ids = GetSelectedDevices();
  for (auto index = 0u; index < device_ids.size() - 1; index++) {
    const gpuDeviceProp& device_property =
        GetDeviceProperties(device_ids[index]);
    output_file_stream_ << string_format(
        std::string(
            R"JSON(
    {
       "id": %d, "name": "%s", "totalGlobalMem": %u,
      "computeMajor": %d, "computeMinor": %d,
      "maxThreadsPerBlock": %d, "maxThreadsPerMultiprocessor": %d,
      "regsPerBlock": %d, "regsPerMultiprocessor": %d, "warpSize": %d,
      "sharedMemPerBlock": %d, "sharedMemPerMultiprocessor": %d,
      "smCount": %d, "sharedMemPerBlockOptin": %d
    },
  )JSON"),
        device_ids[index], device_property.name, device_property.totalGlobalMem,
        device_property.major, device_property.minor,
        device_property.maxThreadsPerBlock,
        device_property.maxThreadsPerMultiProcessor,
        device_property.regsPerBlock, device_property.regsPerMultiprocessor,
        device_property.warpSize, device_property.sharedMemPerBlock,
        device_property.sharedMemPerMultiprocessor,
        device_property.multiProcessorCount,
        device_property.sharedMemPerBlockOptin);
  }
  if (device_ids.size() > 0) {
    const gpuDeviceProp& device_property =
        GetDeviceProperties(device_ids[device_ids.size() - 1]);
    output_file_stream_ << string_format(
        std::string(
            R"JSON(
    {
       "id": %d, "name": "%s", "totalGlobalMem": %u,
      "computeMajor": %d, "computeMinor": %d,
      "maxThreadsPerBlock": %d, "maxThreadsPerMultiprocessor": %d,
      "regsPerBlock": %d, "regsPerMultiprocessor": %d, "warpSize": %d,
      "sharedMemPerBlock": %d, "sharedMemPerMultiprocessor": %d,
      "smCount": %d, "sharedMemPerBlockOptin": %d
    }],
  )JSON"),
        device_ids[device_ids.size() - 1], device_property.name,
        device_property.totalGlobalMem, device_property.major,
        device_property.minor, device_property.maxThreadsPerBlock,
        device_property.maxThreadsPerMultiProcessor,
        device_property.regsPerBlock, device_property.regsPerMultiprocessor,
        device_property.warpSize, device_property.sharedMemPerBlock,
        device_property.sharedMemPerMultiprocessor,
        device_property.multiProcessorCount,
        device_property.sharedMemPerBlockOptin);
  }
#endif

  output_file_stream_ << std::string(
      R"JSON(
    "traceEvents": [
  )JSON");
}

C
chenjian 已提交
435 436 437
void ChromeTracingLogger::LogMetaInfo(
    const std::unordered_map<std::string, std::string> extra_info) {
  RefineDisplayName(extra_info);
438 439 440
  output_file_stream_ << std::string(
      R"JSON(
  {}
C
chenjian 已提交
441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571
  ],
  )JSON");
  output_file_stream_ << std::string(R"JSON(
  "ExtraInfo": {)JSON");
  size_t count = extra_info.size();
  for (const auto& kv : extra_info) {
    if (count > 1) {
      output_file_stream_ << string_format(std::string(R"JSON(
     "%s": "%s",
   )JSON"),
                                           kv.first.c_str(), kv.second.c_str());
    } else {
      output_file_stream_ << string_format(std::string(R"JSON(
     "%s": "%s"
   )JSON"),
                                           kv.first.c_str(), kv.second.c_str());
    }
    count--;
  }
  output_file_stream_ << std::string(R"JSON(
  })JSON");
}

// Emits Chrome metadata ("ph": "M") events that name and order every track
// recorded by the Log* methods.
//  - Host: for each (pid, tid) in pid_tid_set_, this names the process and
//    its "<tid>(Python)" / "<tid>(C++)" tracks. The thread label is
//    extra_info looked up by the decimal tid (empty if absent). Python
//    sorts before C++ (2*tid vs 2*tid+1).
//  - Device: for each (device, stream) in deviceid_streamid_set_, this names
//    the device process and stream track. Devices are offset by 0x10000000
//    in process sort order.
void ChromeTracingLogger::RefineDisplayName(
    std::unordered_map<std::string, std::string> extra_info) {
  for (const auto& pid_tid : pid_tid_set_) {
    // One lookup per pair; find() avoids inserting empty entries.
    const auto name_it =
        extra_info.find(string_format(std::string("%lld"), pid_tid.second));
    const char* thread_name =
        name_it != extra_info.end() ? name_it->second.c_str() : "";
    output_file_stream_ << string_format(
        std::string(
            R"JSON(
  {
    "name": "process_name", "pid": %lld, "tid": "%lld(Python)",
    "ph": "M", 
    "args": {
      "name": "Process %lld (CPU)"
    }
  },
  {
    "name": "process_name", "pid": %lld, "tid": "%lld(C++)",
    "ph": "M", 
    "args": {
      "name": "Process %lld (CPU)"
    }
  },
   {
    "name": "thread_name", "pid": %lld, "tid": "%lld(Python)",
    "ph": "M", 
    "args": {
      "name": "thread %lld:%s(Python)"
    }
  },
  {
    "name": "thread_name", "pid": %lld, "tid": "%lld(C++)",
    "ph": "M", 
    "args": {
      "name": "thread %lld:%s(C++)"
    }
  },
  {
    "name": "process_sort_index", "pid": %lld, "tid": %lld,
    "ph": "M", 
    "args": {
      "sort_index": %lld
    }
  },  
  {
    "name": "thread_sort_index", "pid": %lld, "tid": "%lld(Python)",
    "ph": "M", 
    "args": {
      "sort_index": %lld
    }
  },
  {
    "name": "thread_sort_index", "pid": %lld, "tid": "%lld(C++)",
    "ph": "M", 
    "args": {
      "sort_index": %lld
    }
  },
  )JSON"),
        // process_name (Python track, then C++ track)
        pid_tid.first, pid_tid.second, pid_tid.first,
        pid_tid.first, pid_tid.second, pid_tid.first,
        // thread_name (Python track, then C++ track)
        pid_tid.first, pid_tid.second, pid_tid.second, thread_name,
        pid_tid.first, pid_tid.second, pid_tid.second, thread_name,
        // process_sort_index
        pid_tid.first, pid_tid.second, pid_tid.first,
        // thread_sort_index (Python track, then C++ track)
        pid_tid.first, pid_tid.second, pid_tid.second * 2,
        pid_tid.first, pid_tid.second, pid_tid.second * 2 + 1);
  }

  for (const auto& device_stream : deviceid_streamid_set_) {
    output_file_stream_ << string_format(
        std::string(
            R"JSON(
  {
    "name": "process_name", "pid": %lld, "tid": %lld,
    "ph": "M", 
    "args": {
      "name": "Device %lld (GPU)"
    }
  },
   {
    "name": "thread_name", "pid": %lld, "tid": %lld,
    "ph": "M", 
    "args": {
      "name": "stream %lld"
    }
  },
  {
    "name": "process_sort_index", "pid": %lld, "tid": %lld,
    "ph": "M", 
    "args": {
      "sort_index": %lld
    }
  },  
  {
    "name": "thread_sort_index", "pid": %lld, "tid": %lld,
    "ph": "M", 
    "args": {
      "sort_index": %lld
    }
  },  
  )JSON"),
        device_stream.first, device_stream.second, device_stream.first,
        device_stream.first, device_stream.second, device_stream.second,
        device_stream.first, device_stream.second,
        device_stream.first + 0x10000000, device_stream.first,
        device_stream.second, device_stream.second);
  }
}

// Closes the top-level JSON object opened by StartLog().
void ChromeTracingLogger::EndLog() {
  output_file_stream_ << std::string(
      R"JSON(
  }
  )JSON");
}

}  // namespace platform
}  // namespace paddle