提交 f3cbfc02 编写于 作者: X Xin Pan

Add MEMCPY information

上级 55b2d3d0
...@@ -55,6 +55,36 @@ uint64_t kAlignSize = 8; ...@@ -55,6 +55,36 @@ uint64_t kAlignSize = 8;
} \ } \
} while (0) } while (0)
std::string MemcpyKind(CUpti_ActivityMemcpyKind kind) {
switch (kind) {
case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD:
return "MEMCPY_HtoD";
case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH:
return "MEMCPY_DtoH";
case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA:
return "MEMCPY_HtoA";
case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH:
return "MEMCPY_AtoH";
case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA:
return "MEMCPY_AtoA";
case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD:
return "MEMCPY_AtoD";
case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA:
return "MEMCPY_DtoA";
case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD:
return "MEMCPY_DtoD";
case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH:
return "MEMCPY_HtoH";
case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP:
return "MEMCPY_PtoP";
case CUPTI_ACTIVITY_MEMCPY_KIND_FORCE_INT:
return "MEMCPY_FORCE_INT";
default:
break;
}
return "MEMCPY";
}
void EnableActivity() { void EnableActivity() {
// Device activity record is created when CUDA initializes, so we // Device activity record is created when CUDA initializes, so we
// want to enable it before cuInit() or any CUDA runtime call. // want to enable it before cuInit() or any CUDA runtime call.
...@@ -111,6 +141,26 @@ void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer, ...@@ -111,6 +141,26 @@ void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
kernel->correlationId); kernel->correlationId);
break; break;
} }
case CUPTI_ACTIVITY_KIND_MEMCPY: {
auto *memcpy =
reinterpret_cast<const CUpti_ActivityMemcpy *>(record);
tracer->AddMemRecords(
MemcpyKind(
static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
memcpy->correlationId, memcpy->bytes);
break;
}
case CUPTI_ACTIVITY_KIND_MEMCPY2: {
auto *memcpy =
reinterpret_cast<const CUpti_ActivityMemcpy2 *>(record);
tracer->AddMemRecords(
MemcpyKind(
static_cast<CUpti_ActivityMemcpyKind>(memcpy->copyKind)),
memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
memcpy->correlationId, memcpy->bytes);
break;
}
default: { break; } default: { break; }
} }
} else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
...@@ -148,6 +198,13 @@ class DeviceTracerImpl : public DeviceTracer { ...@@ -148,6 +198,13 @@ class DeviceTracerImpl : public DeviceTracer {
std::hash<std::thread::id>{}(std::this_thread::get_id())}); std::hash<std::thread::id>{}(std::this_thread::get_id())});
} }
void AddMemRecords(const std::string &name, uint64_t start_ns,
uint64_t end_ns, uint32_t device_id, uint32_t stream_id,
uint32_t correlation_id, uint64_t bytes) {
mem_records_.push_back(MemRecord{name, start_ns, end_ns, device_id,
stream_id, correlation_id, bytes});
}
void AddKernelRecords(uint64_t start, uint64_t end, uint32_t device_id, void AddKernelRecords(uint64_t start, uint64_t end, uint32_t device_id,
uint32_t stream_id, uint32_t correlation_id) { uint32_t stream_id, uint32_t correlation_id) {
std::lock_guard<std::mutex> l(trace_mu_); std::lock_guard<std::mutex> l(trace_mu_);
...@@ -183,7 +240,6 @@ class DeviceTracerImpl : public DeviceTracer { ...@@ -183,7 +240,6 @@ class DeviceTracerImpl : public DeviceTracer {
CUPTI_CALL( CUPTI_CALL(
dynload::cuptiEnableCallback(1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, dynload::cuptiEnableCallback(1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel)); CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel));
CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_)); CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_));
enabled_ = true; enabled_ = true;
} }
...@@ -214,6 +270,15 @@ class DeviceTracerImpl : public DeviceTracer { ...@@ -214,6 +270,15 @@ class DeviceTracerImpl : public DeviceTracer {
event->set_stream_id(r.thread_id); event->set_stream_id(r.thread_id);
event->set_device_id(-1); event->set_device_id(-1);
} }
for (const MemRecord &r : mem_records_) {
auto *event = profile_pb.add_events();
event->set_name(r.name);
event->set_start_ns(r.start_ns);
event->set_end_ns(r.end_ns);
event->set_stream_id(r.stream_id);
event->set_device_id(r.device_id);
event->mutable_memcopy()->set_bytes(r.bytes);
}
std::string profile_str; std::string profile_str;
google::protobuf::TextFormat::PrintToString(profile_pb, &profile_str); google::protobuf::TextFormat::PrintToString(profile_pb, &profile_str);
std::ofstream profile_f; std::ofstream profile_f;
...@@ -257,6 +322,7 @@ class DeviceTracerImpl : public DeviceTracer { ...@@ -257,6 +322,7 @@ class DeviceTracerImpl : public DeviceTracer {
uint64_t start_ns_; uint64_t start_ns_;
uint64_t end_ns_; uint64_t end_ns_;
std::vector<KernelRecord> kernel_records_; std::vector<KernelRecord> kernel_records_;
std::vector<MemRecord> mem_records_;
std::vector<CPURecord> cpu_records_; std::vector<CPURecord> cpu_records_;
std::unordered_map<uint32_t, std::string> correlations_; std::unordered_map<uint32_t, std::string> correlations_;
CUpti_SubscriberHandle subscriber_; CUpti_SubscriberHandle subscriber_;
...@@ -272,6 +338,10 @@ class DeviceTracerDummy : public DeviceTracer { ...@@ -272,6 +338,10 @@ class DeviceTracerDummy : public DeviceTracer {
void AddCPURecords(const char *anno, uint64_t start_ns, uint64_t end_ns) {} void AddCPURecords(const char *anno, uint64_t start_ns, uint64_t end_ns) {}
void AddMemRecords(const std::string &name, uint64_t start_ns,
uint64_t end_ns, uint32_t device_id, uint32_t stream_id,
uint32_t correlation_id, uint64_t bytes) {}
void AddKernelRecords(uint64_t start, uint64_t end, uint32_t device_id, void AddKernelRecords(uint64_t start, uint64_t end, uint32_t device_id,
uint32_t stream_id, uint32_t correlation_id) {} uint32_t stream_id, uint32_t correlation_id) {}
......
...@@ -42,6 +42,15 @@ class DeviceTracer { ...@@ -42,6 +42,15 @@ class DeviceTracer {
uint64_t end_ns; uint64_t end_ns;
uint64_t thread_id; uint64_t thread_id;
}; };
struct MemRecord {
std::string name;
uint64_t start_ns;
uint64_t end_ns;
uint32_t device_id;
uint32_t stream_id;
uint32_t correlation_id;
uint64_t bytes;
};
virtual ~DeviceTracer() {} virtual ~DeviceTracer() {}
// Needs to be called once before use. // Needs to be called once before use.
...@@ -54,6 +63,11 @@ class DeviceTracer { ...@@ -54,6 +63,11 @@ class DeviceTracer {
// human-readable annotations. // human-readable annotations.
virtual void AddAnnotation(uint64_t id, const std::string& anno) = 0; virtual void AddAnnotation(uint64_t id, const std::string& anno) = 0;
virtual void AddMemRecords(const std::string& name, uint64_t start_ns,
uint64_t end_ns, uint32_t device_id,
uint32_t stream_id, uint32_t correlation_id,
uint64_t bytes) = 0;
virtual void AddCPURecords(const char* anno, uint64_t start_ns, virtual void AddCPURecords(const char* anno, uint64_t start_ns,
uint64_t end_ns) = 0; uint64_t end_ns) = 0;
......
...@@ -74,7 +74,8 @@ extern void *cupti_dso_handle; ...@@ -74,7 +74,8 @@ extern void *cupti_dso_handle;
__macro(cuptiFinalize); \ __macro(cuptiFinalize); \
__macro(cuptiSubscribe); \ __macro(cuptiSubscribe); \
__macro(cuptiUnsubscribe); \ __macro(cuptiUnsubscribe); \
__macro(cuptiEnableCallback); __macro(cuptiEnableCallback); \
__macro(cuptiEnableDomain);
CUPTI_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUPTI_WRAP); CUPTI_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUPTI_WRAP);
......
...@@ -15,13 +15,17 @@ limitations under the License. */ ...@@ -15,13 +15,17 @@ limitations under the License. */
syntax = "proto2"; syntax = "proto2";
package paddle.platform.proto; package paddle.platform.proto;
message MemCopy { optional uint64 bytes = 3; }
message Event { message Event {
optional string name = 1; optional string name = 1;
optional uint64 start_ns = 2; optional uint64 start_ns = 2;
optional uint64 end_ns = 3; optional uint64 end_ns = 3;
// When positive, it represents gpu id. When -1, it represents CPU. // When positive, it represents gpu id. When -1, it represents CPU.
optional int32 device_id = 5; optional int64 device_id = 5;
optional uint32 stream_id = 6; optional uint32 stream_id = 6;
optional MemCopy memcopy = 7;
} }
message Profile { message Profile {
......
...@@ -135,6 +135,8 @@ class Timeline(object): ...@@ -135,6 +135,8 @@ class Timeline(object):
for event in self._profile_pb.events: for event in self._profile_pb.events:
pid = self._devices[event.device_id] pid = self._devices[event.device_id]
args = {'name': event.name} args = {'name': event.name}
if event.memcopy.bytes > 0:
args = {'mem_bytes': event.memcopy.bytes}
# TODO(panyx0718): Chrome tracing only handles ms. However, some # TODO(panyx0718): Chrome tracing only handles ms. However, some
# ops takes micro-seconds. Hence, we keep the ns here. # ops takes micro-seconds. Hence, we keep the ns here.
self._chrome_trace.emit_region(event.start_ns, self._chrome_trace.emit_region(event.start_ns,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册