未验证 提交 360b8383 编写于 作者: C chenjian 提交者: GitHub

add serialization for new field in event node (#43405)

* add serialization for new field in event node

* fix a bug
上级 30b10630
...@@ -92,6 +92,26 @@ std::unique_ptr<ProfilerResult> DeserializationReader::Parse() { ...@@ -92,6 +92,26 @@ std::unique_ptr<ProfilerResult> DeserializationReader::Parse() {
device_node); // insert into runtime_node device_node); // insert into runtime_node
} }
} }
// handle mem node
for (int mem_node_index = 0;
mem_node_index < host_node_proto.mem_nodes_size();
mem_node_index++) {
const MemTraceEventNodeProto& mem_node_proto =
host_node_proto.mem_nodes(mem_node_index);
MemTraceEventNode* mem_node = RestoreMemTraceEventNode(mem_node_proto);
host_node->AddMemNode(mem_node);
}
// handle op supplement node
for (int op_supplement_node_index = 0;
op_supplement_node_index <
host_node_proto.op_supplement_nodes_size();
op_supplement_node_index++) {
const OperatorSupplementEventNodeProto& op_supplement_node_proto =
host_node_proto.op_supplement_nodes(op_supplement_node_index);
OperatorSupplementEventNode* op_supplement_node =
RestoreOperatorSupplementEventNode(op_supplement_node_proto);
host_node->SetOperatorSupplementNode(op_supplement_node);
}
} }
// restore parent-child relationship // restore parent-child relationship
for (auto it = child_parent_map.begin(); it != child_parent_map.end(); for (auto it = child_parent_map.begin(); it != child_parent_map.end();
...@@ -176,6 +196,62 @@ HostTraceEventNode* DeserializationReader::RestoreHostTraceEventNode( ...@@ -176,6 +196,62 @@ HostTraceEventNode* DeserializationReader::RestoreHostTraceEventNode(
return new HostTraceEventNode(host_event); return new HostTraceEventNode(host_event);
} }
// Rebuilds a MemTraceEventNode from its serialized protobuf form.
//
// Every field of the MemTraceEventProto wrapped by `mem_node_proto` is copied
// into a local MemTraceEvent, which is then used to construct the node.
// The node is allocated with `new` and returned as a raw pointer; Parse()
// passes it on to HostTraceEventNode::AddMemNode.
MemTraceEventNode* DeserializationReader::RestoreMemTraceEventNode(
    const MemTraceEventNodeProto& mem_node_proto) {
  const MemTraceEventProto& proto = mem_node_proto.mem_event();

  MemTraceEvent event;
  // Where and when the record was taken.
  event.process_id = proto.process_id();
  event.thread_id = proto.thread_id();
  event.timestamp_ns = proto.timestamp_ns();
  // The memory operation itself. The proto enum is converted to the
  // in-memory enum by value.
  event.type = static_cast<TracerMemEventType>(proto.type());
  event.addr = proto.addr();
  event.place = proto.place();
  event.increase_bytes = proto.increase_bytes();
  // Running totals reported with the record.
  event.current_allocated = proto.current_allocated();
  event.current_reserved = proto.current_reserved();

  MemTraceEventNode* restored_node = new MemTraceEventNode(event);
  return restored_node;
}
// Rebuilds an OperatorSupplementEventNode from its serialized protobuf form.
//
// Scalar fields (timestamp, op type, call stack, process/thread id) are copied
// directly. The input shapes and dtypes are stored in the proto as two
// parallel repeated fields each (`key[i]` pairs with `shape_vecs[i]` /
// `dtype_vecs[i]`), and are converted back into the std::map form used by
// OperatorSupplementEvent.
//
// The node is allocated with `new` and returned as a raw pointer; Parse()
// passes it on to HostTraceEventNode::SetOperatorSupplementNode.
OperatorSupplementEventNode*
DeserializationReader::RestoreOperatorSupplementEventNode(
    const OperatorSupplementEventNodeProto& op_supplement_node_proto) {
  const OperatorSupplementEventProto& op_supplement_event_proto =
      op_supplement_node_proto.op_supplement_event();
  OperatorSupplementEvent op_supplement_event;
  op_supplement_event.timestamp_ns = op_supplement_event_proto.timestamp_ns();
  op_supplement_event.op_type = op_supplement_event_proto.op_type();
  op_supplement_event.callstack = op_supplement_event_proto.callstack();
  op_supplement_event.process_id = op_supplement_event_proto.process_id();
  op_supplement_event.thread_id = op_supplement_event_proto.thread_id();

  // Restore input shapes: name -> list of shapes -> list of dimension sizes.
  std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
  const auto& input_shape_proto = op_supplement_event_proto.input_shapes();
  // Bound the index by both parallel lists so that a malformed file with
  // more keys than values cannot index past the end of `shape_vecs`.
  for (int i = 0; i < input_shape_proto.key_size() &&
                  i < input_shape_proto.shape_vecs_size();
       ++i) {
    // Must be a reference: a plain `auto` would copy the map entry and every
    // push_back below would be lost, leaving the entry empty.
    auto& input_shape_vec = input_shapes[input_shape_proto.key(i)];
    const auto& shape_vectors_proto = input_shape_proto.shape_vecs(i);
    for (int j = 0; j < shape_vectors_proto.shapes_size(); ++j) {
      const auto& shape_vector_proto = shape_vectors_proto.shapes(j);
      std::vector<int64_t> shape;
      for (int k = 0; k < shape_vector_proto.size_size(); ++k) {
        // The proto stores sizes as uint64; convert back to the signed type
        // used in memory.
        shape.push_back(static_cast<int64_t>(shape_vector_proto.size(k)));
      }
      input_shape_vec.push_back(shape);
    }
  }
  op_supplement_event.input_shapes = input_shapes;

  // Restore dtypes: name -> list of dtype strings.
  std::map<std::string, std::vector<std::string>> dtypes;
  const auto& dtype_proto = op_supplement_event_proto.dtypes();
  for (int i = 0;
       i < dtype_proto.key_size() && i < dtype_proto.dtype_vecs_size(); ++i) {
    // Reference for the same reason as above.
    auto& dtype_vec = dtypes[dtype_proto.key(i)];
    const auto& dtype_vec_proto = dtype_proto.dtype_vecs(i);
    for (int j = 0; j < dtype_vec_proto.dtype_size(); ++j) {
      dtype_vec.push_back(dtype_vec_proto.dtype(j));
    }
  }
  op_supplement_event.dtypes = dtypes;

  return new OperatorSupplementEventNode(op_supplement_event);
}
KernelEventInfo DeserializationReader::HandleKernelEventInfoProto( KernelEventInfo DeserializationReader::HandleKernelEventInfoProto(
const DeviceTraceEventProto& device_event_proto) { const DeviceTraceEventProto& device_event_proto) {
const KernelEventInfoProto& kernel_info_proto = const KernelEventInfoProto& kernel_info_proto =
......
...@@ -36,6 +36,9 @@ class DeserializationReader { ...@@ -36,6 +36,9 @@ class DeserializationReader {
KernelEventInfo HandleKernelEventInfoProto(const DeviceTraceEventProto&); KernelEventInfo HandleKernelEventInfoProto(const DeviceTraceEventProto&);
MemcpyEventInfo HandleMemcpyEventInfoProto(const DeviceTraceEventProto&); MemcpyEventInfo HandleMemcpyEventInfoProto(const DeviceTraceEventProto&);
MemsetEventInfo HandleMemsetEventInfoProto(const DeviceTraceEventProto&); MemsetEventInfo HandleMemsetEventInfoProto(const DeviceTraceEventProto&);
MemTraceEventNode* RestoreMemTraceEventNode(const MemTraceEventNodeProto&);
OperatorSupplementEventNode* RestoreOperatorSupplementEventNode(
const OperatorSupplementEventNodeProto&);
std::string filename_; std::string filename_;
std::ifstream input_file_stream_; std::ifstream input_file_stream_;
NodeTreesProto* node_trees_proto_; NodeTreesProto* node_trees_proto_;
......
...@@ -46,6 +46,15 @@ enum TracerEventTypeProto { ...@@ -46,6 +46,15 @@ enum TracerEventTypeProto {
PythonOp = 13; PythonOp = 13;
// Used to mark python level userdefined // Used to mark python level userdefined
PythonUserDefined = 14; PythonUserDefined = 14;
// Used to mark mlu runtime record returned by cnpapi
MluRuntime = 15;
};
// Kind of memory operation recorded by a MemTraceEventProto.
// The C++ serialization and deserialization code converts between this enum
// and TracerMemEventType with a plain static_cast, so the numeric values here
// are expected to line up with that enum.
enum TracerMemEventTypeProto {
// Used to mark memory allocation
Allocate = 0;
// Used to mark memory free
Free = 1;
}; };
message KernelEventInfoProto { message KernelEventInfoProto {
...@@ -121,6 +130,58 @@ message HostTraceEventProto { ...@@ -121,6 +130,58 @@ message HostTraceEventProto {
required uint64 thread_id = 6; required uint64 thread_id = 6;
} }
// Serialized form of one memory trace record (a single allocation or free).
// Every field is declared `required`, so a writer has to populate all of them.
message MemTraceEventProto {
// timestamp of the record (nanoseconds, going by the field name)
required uint64 timestamp_ns = 1;
// memory manipulation type (allocation or free)
required TracerMemEventTypeProto type = 2;
// memory address that was allocated or freed
required uint64 addr = 3;
// process id of the record
required uint64 process_id = 4;
// thread id of the record
required uint64 thread_id = 5;
// signed change in bytes caused by this operation: positive for an
// allocation, negative for a free
required int64 increase_bytes = 6;
// place the memory belongs to, as a string (the unit test uses "GPU:0")
required string place = 7;
// current total allocated memory
required uint64 current_allocated = 8;
// current total reserved memory
required uint64 current_reserved = 9;
}
// Serialized form of the supplementary information recorded for an operator:
// its input shapes, input dtypes and call stack.
// The two map-like fields are encoded as parallel repeated fields: entry i of
// `key` belongs to entry i of `shape_vecs` / `dtype_vecs`.
message OperatorSupplementEventProto {
// timestamp of the record (nanoseconds, going by the field name)
required uint64 timestamp_ns = 1;
// op type name
required string op_type = 2;
// process id of the record
required uint64 process_id = 3;
// thread id of the record
required uint64 thread_id = 4;
// input shapes: for each key, a list of shapes, each a list of dimension
// sizes
message input_shape_proto {
// keys of the shape map, one per entry of shape_vecs
repeated string key = 1;
// all shapes stored under one key
message shape_vector {
// one shape; the C++ side holds the sizes as int64_t while they are stored
// here as uint64
// NOTE(review): a negative dimension would only round-trip through the
// unsigned conversion - confirm whether negative sizes can occur
message shape { repeated uint64 size = 1; }
repeated shape shapes = 1;
}
// values of the shape map, parallel to key
repeated shape_vector shape_vecs = 2;
}
required input_shape_proto input_shapes = 5;
// dtypes: for each key, a list of dtype names
message dtype_proto {
// keys of the dtype map, one per entry of dtype_vecs
repeated string key = 1;
// all dtype names stored under one key
message dtype_vector { repeated string dtype = 1; }
// values of the dtype map, parallel to key
repeated dtype_vector dtype_vecs = 2;
}
required dtype_proto dtypes = 6;
// call stack
required string callstack = 7;
}
message CudaRuntimeTraceEventProto { message CudaRuntimeTraceEventProto {
// record name // record name
required string name = 1; required string name = 1;
...@@ -166,6 +227,12 @@ message DeviceTraceEventProto { ...@@ -166,6 +227,12 @@ message DeviceTraceEventProto {
} }
} }
// Tree-node wrapper around a single OperatorSupplementEventProto.
// HostTraceEventNodeProto stores these in its op_supplement_nodes field.
message OperatorSupplementEventNodeProto {
required OperatorSupplementEventProto op_supplement_event = 1;
}
// Tree-node wrapper around a single MemTraceEventProto.
// HostTraceEventNodeProto stores these in its mem_nodes field.
message MemTraceEventNodeProto { required MemTraceEventProto mem_event = 1; }
message DeviceTraceEventNodeProto { message DeviceTraceEventNodeProto {
required DeviceTraceEventProto device_event = 1; required DeviceTraceEventProto device_event = 1;
} }
...@@ -180,6 +247,9 @@ message HostTraceEventNodeProto { ...@@ -180,6 +247,9 @@ message HostTraceEventNodeProto {
required int64 parentid = 2; required int64 parentid = 2;
required HostTraceEventProto host_trace_event = 3; required HostTraceEventProto host_trace_event = 3;
repeated CudaRuntimeTraceEventNodeProto runtime_nodes = 4; repeated CudaRuntimeTraceEventNodeProto runtime_nodes = 4;
// below is added in version 1.0.1
repeated MemTraceEventNodeProto mem_nodes = 5;
repeated OperatorSupplementEventNodeProto op_supplement_nodes = 6;
} }
message ThreadNodeTreeProto { message ThreadNodeTreeProto {
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace platform { namespace platform {
static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.pb"; static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.pb";
static const char* version = "1.0.0"; static const char* version = "1.0.1";
static uint32_t span_indx = 0; static uint32_t span_indx = 0;
static std::string DefaultFileName() { static std::string DefaultFileName() {
...@@ -106,10 +106,33 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) { ...@@ -106,10 +106,33 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
(*devicenode)->LogMe(this); // fill detail information (*devicenode)->LogMe(this); // fill detail information
} }
} }
for (auto memnode = (*hostnode)->GetMemTraceEventNodes().begin();
memnode != (*hostnode)->GetMemTraceEventNodes().end(); ++memnode) {
MemTraceEventNodeProto* mem_node_proto =
current_host_trace_event_node_proto_->add_mem_nodes();
current_mem_trace_event_node_proto_ = mem_node_proto;
(*memnode)->LogMe(this);
}
} }
} }
} }
void SerializationLogger::LogMemTraceEventNode(
const MemTraceEventNode& mem_node) {
MemTraceEventProto* mem_trace_event = new MemTraceEventProto();
mem_trace_event->set_timestamp_ns(mem_node.TimeStampNs());
mem_trace_event->set_type(
static_cast<TracerMemEventTypeProto>(mem_node.Type()));
mem_trace_event->set_addr(mem_node.Addr());
mem_trace_event->set_process_id(mem_node.ProcessId());
mem_trace_event->set_thread_id(mem_node.ThreadId());
mem_trace_event->set_increase_bytes(mem_node.IncreaseBytes());
mem_trace_event->set_place(mem_node.Place());
mem_trace_event->set_current_allocated(mem_node.CurrentAllocated());
mem_trace_event->set_current_reserved(mem_node.CurrentReserved());
current_mem_trace_event_node_proto_->set_allocated_mem_event(mem_trace_event);
}
void SerializationLogger::LogHostTraceEventNode( void SerializationLogger::LogHostTraceEventNode(
const HostTraceEventNode& host_node) { const HostTraceEventNode& host_node) {
HostTraceEventProto* host_trace_event = new HostTraceEventProto(); HostTraceEventProto* host_trace_event = new HostTraceEventProto();
...@@ -122,6 +145,59 @@ void SerializationLogger::LogHostTraceEventNode( ...@@ -122,6 +145,59 @@ void SerializationLogger::LogHostTraceEventNode(
host_trace_event->set_thread_id(host_node.ThreadId()); host_trace_event->set_thread_id(host_node.ThreadId());
current_host_trace_event_node_proto_->set_allocated_host_trace_event( current_host_trace_event_node_proto_->set_allocated_host_trace_event(
host_trace_event); host_trace_event);
OperatorSupplementEventNode* op_supplement_event_node =
host_node.GetOperatorSupplementEventNode();
if (op_supplement_event_node != nullptr) {
current_op_supplement_event_node_proto_ =
current_host_trace_event_node_proto_->add_op_supplement_nodes();
OperatorSupplementEventProto* op_supplement_event_proto =
new OperatorSupplementEventProto();
op_supplement_event_proto->set_op_type(op_supplement_event_node->Name());
op_supplement_event_proto->set_timestamp_ns(
op_supplement_event_node->TimeStampNs());
op_supplement_event_proto->set_process_id(
op_supplement_event_node->ProcessId());
op_supplement_event_proto->set_thread_id(
op_supplement_event_node->ThreadId());
op_supplement_event_proto->set_callstack(
op_supplement_event_node->CallStack());
OperatorSupplementEventProto::input_shape_proto* input_shape_proto =
op_supplement_event_proto->mutable_input_shapes();
for (auto it = op_supplement_event_node->InputShapes().begin();
it != op_supplement_event_node->InputShapes().end(); it++) {
input_shape_proto->add_key(it->first);
OperatorSupplementEventProto::input_shape_proto::shape_vector*
shape_vectors_proto = input_shape_proto->add_shape_vecs();
auto shape_vectors = it->second;
for (auto shape_vecs_it = shape_vectors.begin();
shape_vecs_it != shape_vectors.end(); shape_vecs_it++) {
auto shape_vector = *shape_vecs_it;
OperatorSupplementEventProto::input_shape_proto::shape_vector::shape*
shape_proto = shape_vectors_proto->add_shapes();
for (auto shape_it = shape_vector.begin();
shape_it != shape_vector.end(); shape_it++) {
shape_proto->add_size(*shape_it);
}
}
}
OperatorSupplementEventProto::dtype_proto* dtype_proto =
op_supplement_event_proto->mutable_dtypes();
for (auto it = op_supplement_event_node->Dtypes().begin();
it != op_supplement_event_node->Dtypes().end(); it++) {
dtype_proto->add_key(it->first);
OperatorSupplementEventProto::dtype_proto::dtype_vector*
dtype_vector_proto = dtype_proto->add_dtype_vecs();
auto dtype_vector = it->second;
for (auto dtype_it = dtype_vector.begin(); dtype_it != dtype_vector.end();
dtype_it++) {
dtype_vector_proto->add_dtype(*dtype_it);
}
}
current_op_supplement_event_node_proto_->set_allocated_op_supplement_event(
op_supplement_event_proto);
}
} }
void SerializationLogger::LogRuntimeTraceEventNode( void SerializationLogger::LogRuntimeTraceEventNode(
......
...@@ -34,6 +34,7 @@ class SerializationLogger : public BaseLogger { ...@@ -34,6 +34,7 @@ class SerializationLogger : public BaseLogger {
void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override; void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override;
void LogNodeTrees(const NodeTrees&) override; void LogNodeTrees(const NodeTrees&) override;
void LogMetaInfo(const std::unordered_map<std::string, std::string>); void LogMetaInfo(const std::unordered_map<std::string, std::string>);
void LogMemTraceEventNode(const MemTraceEventNode&) override;
private: private:
void OpenFile(); void OpenFile();
...@@ -48,6 +49,8 @@ class SerializationLogger : public BaseLogger { ...@@ -48,6 +49,8 @@ class SerializationLogger : public BaseLogger {
HostTraceEventNodeProto* current_host_trace_event_node_proto_; HostTraceEventNodeProto* current_host_trace_event_node_proto_;
CudaRuntimeTraceEventNodeProto* current_runtime_trace_event_node_proto_; CudaRuntimeTraceEventNodeProto* current_runtime_trace_event_node_proto_;
DeviceTraceEventNodeProto* current_device_trace_event_node_proto_; DeviceTraceEventNodeProto* current_device_trace_event_node_proto_;
MemTraceEventNodeProto* current_mem_trace_event_node_proto_;
OperatorSupplementEventNodeProto* current_op_supplement_event_node_proto_;
}; };
} // namespace platform } // namespace platform
......
...@@ -34,6 +34,7 @@ using paddle::platform::ProfilerResult; ...@@ -34,6 +34,7 @@ using paddle::platform::ProfilerResult;
using paddle::platform::RuntimeTraceEvent; using paddle::platform::RuntimeTraceEvent;
using paddle::platform::SerializationLogger; using paddle::platform::SerializationLogger;
using paddle::platform::TracerEventType; using paddle::platform::TracerEventType;
using paddle::platform::TracerMemEventType;
TEST(SerializationLoggerTest, dump_case0) { TEST(SerializationLoggerTest, dump_case0) {
std::list<HostTraceEvent> host_events; std::list<HostTraceEvent> host_events;
...@@ -50,6 +51,19 @@ TEST(SerializationLoggerTest, dump_case0) { ...@@ -50,6 +51,19 @@ TEST(SerializationLoggerTest, dump_case0) {
std::string("op2"), TracerEventType::Operator, 21000, 30000, 10, 10)); std::string("op2"), TracerEventType::Operator, 21000, 30000, 10, 10));
host_events.push_back(HostTraceEvent( host_events.push_back(HostTraceEvent(
std::string("op3"), TracerEventType::Operator, 31000, 40000, 10, 11)); std::string("op3"), TracerEventType::Operator, 31000, 40000, 10, 11));
mem_events.push_back(MemTraceEvent(11500, 0x1000,
TracerMemEventType::Allocate, 10, 10, 50,
"GPU:0", 50, 50));
mem_events.push_back(MemTraceEvent(11900, 0x1000, TracerMemEventType::Free,
10, 10, -50, "GPU:0", 0, 50));
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
input_shapes[std::string("X")].push_back(std::vector<int64_t>{1, 2, 3});
input_shapes[std::string("X")].push_back(std::vector<int64_t>{4, 5, 6, 7});
dtypes[std::string("X")].push_back(std::string("int8"));
dtypes[std::string("X")].push_back(std::string("float32"));
op_supplement_events.push_back(OperatorSupplementEvent(
11600, "op1", input_shapes, dtypes, "op1()", 10, 10));
runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000, runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000,
17000, 10, 10, 1, 0)); 17000, 10, 10, 1, 0));
runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000, runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000,
...@@ -91,6 +105,8 @@ TEST(SerializationLoggerTest, dump_case0) { ...@@ -91,6 +105,8 @@ TEST(SerializationLoggerTest, dump_case0) {
if ((*it)->Name() == "op1") { if ((*it)->Name() == "op1") {
EXPECT_EQ((*it)->GetChildren().size(), 0u); EXPECT_EQ((*it)->GetChildren().size(), 0u);
EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u);
EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u);
EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr);
} }
} }
for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) {
...@@ -100,6 +116,7 @@ TEST(SerializationLoggerTest, dump_case0) { ...@@ -100,6 +116,7 @@ TEST(SerializationLoggerTest, dump_case0) {
} }
} }
tree.LogMe(&logger); tree.LogMe(&logger);
logger.LogMetaInfo(std::unordered_map<std::string, std::string>());
} }
TEST(SerializationLoggerTest, dump_case1) { TEST(SerializationLoggerTest, dump_case1) {
...@@ -154,6 +171,7 @@ TEST(SerializationLoggerTest, dump_case1) { ...@@ -154,6 +171,7 @@ TEST(SerializationLoggerTest, dump_case1) {
} }
} }
tree.LogMe(&logger); tree.LogMe(&logger);
logger.LogMetaInfo(std::unordered_map<std::string, std::string>());
} }
TEST(DeserializationReaderTest, restore_case0) { TEST(DeserializationReaderTest, restore_case0) {
...@@ -173,6 +191,8 @@ TEST(DeserializationReaderTest, restore_case0) { ...@@ -173,6 +191,8 @@ TEST(DeserializationReaderTest, restore_case0) {
if ((*it)->Name() == "op1") { if ((*it)->Name() == "op1") {
EXPECT_EQ((*it)->GetChildren().size(), 0u); EXPECT_EQ((*it)->GetChildren().size(), 0u);
EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u);
EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u);
EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr);
} }
} }
for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) {
......
...@@ -32,6 +32,9 @@ HostPythonNode::~HostPythonNode() { ...@@ -32,6 +32,9 @@ HostPythonNode::~HostPythonNode() {
for (auto it = device_node_ptrs.begin(); it != device_node_ptrs.end(); ++it) { for (auto it = device_node_ptrs.begin(); it != device_node_ptrs.end(); ++it) {
delete *it; delete *it;
} }
for (auto it = mem_node_ptrs.begin(); it != mem_node_ptrs.end(); ++it) {
delete *it;
}
} }
HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) { HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
...@@ -77,6 +80,29 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) { ...@@ -77,6 +80,29 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
runtime_python_node->device_node_ptrs.push_back(device_python_node); runtime_python_node->device_node_ptrs.push_back(device_python_node);
} }
} }
// copy MemTraceEventNode
for (auto memnode = root->GetMemTraceEventNodes().begin();
memnode != root->GetMemTraceEventNodes().end(); memnode++) {
MemPythonNode* mem_python_node = new MemPythonNode();
mem_python_node->timestamp_ns = (*memnode)->TimeStampNs();
mem_python_node->addr = (*memnode)->Addr();
mem_python_node->type = (*memnode)->Type();
mem_python_node->process_id = (*memnode)->ProcessId();
mem_python_node->thread_id = (*memnode)->ThreadId();
mem_python_node->increase_bytes = (*memnode)->IncreaseBytes();
mem_python_node->place = (*memnode)->Place();
mem_python_node->current_allocated = (*memnode)->CurrentAllocated();
mem_python_node->current_reserved = (*memnode)->CurrentReserved();
host_python_node->mem_node_ptrs.push_back(mem_python_node);
}
// copy OperatorSupplementEventNode's information if exists
OperatorSupplementEventNode* op_supplement_node =
root->GetOperatorSupplementEventNode();
if (op_supplement_node != nullptr) {
host_python_node->input_shapes = op_supplement_node->InputShapes();
host_python_node->dtypes = op_supplement_node->Dtypes();
host_python_node->callstack = op_supplement_node->CallStack();
}
return host_python_node; return host_python_node;
} }
......
...@@ -43,6 +43,31 @@ struct DevicePythonNode { ...@@ -43,6 +43,31 @@ struct DevicePythonNode {
uint64_t stream_id; uint64_t stream_id;
}; };
// Plain-data copy of one memory trace record.
// ProfilerResult::CopyTree fills one of these from each MemTraceEventNode and
// stores it in HostPythonNode::mem_node_ptrs (presumably for consumption by
// the Python bindings, judging by the name).
//
// Every scalar member carries a default initializer so that a
// default-constructed node never holds indeterminate values.
struct MemPythonNode {
  MemPythonNode() = default;
  ~MemPythonNode() {}
  // timestamp of the record
  uint64_t timestamp_ns = 0;
  // memory addr of allocation or free
  uint64_t addr = 0;
  // memory manipulation type; value-initialized, i.e. the enumerator whose
  // underlying value is 0
  TracerMemEventType type{};
  // process id of the record
  uint64_t process_id = 0;
  // thread id of the record
  uint64_t thread_id = 0;
  // signed change in bytes caused by this operation: positive for an
  // allocation, negative for a free
  int64_t increase_bytes = 0;
  // place
  std::string place;
  // current total allocated memory
  uint64_t current_allocated = 0;
  // current total reserved memory
  uint64_t current_reserved = 0;
};
struct HostPythonNode { struct HostPythonNode {
HostPythonNode() = default; HostPythonNode() = default;
~HostPythonNode(); ~HostPythonNode();
...@@ -58,12 +83,19 @@ struct HostPythonNode { ...@@ -58,12 +83,19 @@ struct HostPythonNode {
uint64_t process_id; uint64_t process_id;
// thread id of the record // thread id of the record
uint64_t thread_id; uint64_t thread_id;
// input shapes
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
// call stack
std::string callstack;
// children node // children node
std::vector<HostPythonNode*> children_node_ptrs; std::vector<HostPythonNode*> children_node_ptrs;
// runtime node // runtime node
std::vector<HostPythonNode*> runtime_node_ptrs; std::vector<HostPythonNode*> runtime_node_ptrs;
// device node // device node
std::vector<DevicePythonNode*> device_node_ptrs; std::vector<DevicePythonNode*> device_node_ptrs;
// mem node
std::vector<MemPythonNode*> mem_node_ptrs;
}; };
class ProfilerResult { class ProfilerResult {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册