未验证 提交 360b8383 编写于 作者: C chenjian 提交者: GitHub

add serialization for new field in event node (#43405)

* add serialization for new field in event node

* fix a bug
上级 30b10630
......@@ -92,6 +92,26 @@ std::unique_ptr<ProfilerResult> DeserializationReader::Parse() {
device_node); // insert into runtime_node
}
}
// handle mem node
for (int mem_node_index = 0;
mem_node_index < host_node_proto.mem_nodes_size();
mem_node_index++) {
const MemTraceEventNodeProto& mem_node_proto =
host_node_proto.mem_nodes(mem_node_index);
MemTraceEventNode* mem_node = RestoreMemTraceEventNode(mem_node_proto);
host_node->AddMemNode(mem_node);
}
// handle op supplement node
for (int op_supplement_node_index = 0;
op_supplement_node_index <
host_node_proto.op_supplement_nodes_size();
op_supplement_node_index++) {
const OperatorSupplementEventNodeProto& op_supplement_node_proto =
host_node_proto.op_supplement_nodes(op_supplement_node_index);
OperatorSupplementEventNode* op_supplement_node =
RestoreOperatorSupplementEventNode(op_supplement_node_proto);
host_node->SetOperatorSupplementNode(op_supplement_node);
}
}
// restore parent-child relationship
for (auto it = child_parent_map.begin(); it != child_parent_map.end();
......@@ -176,6 +196,62 @@ HostTraceEventNode* DeserializationReader::RestoreHostTraceEventNode(
return new HostTraceEventNode(host_event);
}
MemTraceEventNode* DeserializationReader::RestoreMemTraceEventNode(
const MemTraceEventNodeProto& mem_node_proto) {
const MemTraceEventProto& mem_event_proto = mem_node_proto.mem_event();
MemTraceEvent mem_event;
mem_event.timestamp_ns = mem_event_proto.timestamp_ns();
mem_event.addr = mem_event_proto.addr();
mem_event.type = static_cast<TracerMemEventType>(mem_event_proto.type());
mem_event.process_id = mem_event_proto.process_id();
mem_event.thread_id = mem_event_proto.thread_id();
mem_event.increase_bytes = mem_event_proto.increase_bytes();
mem_event.place = mem_event_proto.place();
mem_event.current_allocated = mem_event_proto.current_allocated();
mem_event.current_reserved = mem_event_proto.current_reserved();
return new MemTraceEventNode(mem_event);
}
OperatorSupplementEventNode*
DeserializationReader::RestoreOperatorSupplementEventNode(
const OperatorSupplementEventNodeProto& op_supplement_node_proto) {
const OperatorSupplementEventProto& op_supplement_event_proto =
op_supplement_node_proto.op_supplement_event();
OperatorSupplementEvent op_supplement_event;
op_supplement_event.timestamp_ns = op_supplement_event_proto.timestamp_ns();
op_supplement_event.op_type = op_supplement_event_proto.op_type();
op_supplement_event.callstack = op_supplement_event_proto.callstack();
op_supplement_event.process_id = op_supplement_event_proto.process_id();
op_supplement_event.thread_id = op_supplement_event_proto.thread_id();
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
auto input_shape_proto = op_supplement_event_proto.input_shapes();
for (int i = 0; i < input_shape_proto.key_size(); i++) {
auto input_shape_vec = input_shapes[input_shape_proto.key(i)];
auto shape_vectors_proto = input_shape_proto.shape_vecs(i);
for (int j = 0; j < shape_vectors_proto.shapes_size(); j++) {
auto shape_vector_proto = shape_vectors_proto.shapes(j);
std::vector<int64_t> shape;
for (int k = 0; k < shape_vector_proto.size_size(); k++) {
shape.push_back(shape_vector_proto.size(k));
}
input_shape_vec.push_back(shape);
}
}
op_supplement_event.input_shapes = input_shapes;
auto dtype_proto = op_supplement_event_proto.dtypes();
for (int i = 0; i < dtype_proto.key_size(); i++) {
auto dtype_vec = dtypes[dtype_proto.key(i)];
auto dtype_vec_proto = dtype_proto.dtype_vecs(i);
for (int j = 0; j < dtype_vec_proto.dtype_size(); j++) {
auto dtype_string = dtype_vec_proto.dtype(j);
dtype_vec.push_back(dtype_string);
}
}
op_supplement_event.dtypes = dtypes;
return new OperatorSupplementEventNode(op_supplement_event);
}
KernelEventInfo DeserializationReader::HandleKernelEventInfoProto(
const DeviceTraceEventProto& device_event_proto) {
const KernelEventInfoProto& kernel_info_proto =
......
......@@ -36,6 +36,9 @@ class DeserializationReader {
KernelEventInfo HandleKernelEventInfoProto(const DeviceTraceEventProto&);
MemcpyEventInfo HandleMemcpyEventInfoProto(const DeviceTraceEventProto&);
MemsetEventInfo HandleMemsetEventInfoProto(const DeviceTraceEventProto&);
MemTraceEventNode* RestoreMemTraceEventNode(const MemTraceEventNodeProto&);
OperatorSupplementEventNode* RestoreOperatorSupplementEventNode(
const OperatorSupplementEventNodeProto&);
std::string filename_;
std::ifstream input_file_stream_;
NodeTreesProto* node_trees_proto_;
......
......@@ -46,6 +46,15 @@ enum TracerEventTypeProto {
PythonOp = 13;
// Used to mark python level userdefined
PythonUserDefined = 14;
// Used to mark mlu runtime record returned by cnpapi
MluRuntime = 15;
};
enum TracerMemEventTypeProto {
// Used to mark memory allocation
Allocate = 0;
// Used to mark memory free
Free = 1;
};
message KernelEventInfoProto {
......@@ -121,6 +130,58 @@ message HostTraceEventProto {
required uint64 thread_id = 6;
}
message MemTraceEventProto {
// timestamp of the record
required uint64 timestamp_ns = 1;
// memory manipulation type
required TracerMemEventTypeProto type = 2;
// memory addr of allocation or free
required uint64 addr = 3;
// process id of the record
required uint64 process_id = 4;
// thread id of the record
required uint64 thread_id = 5;
// increase bytes after this manipulation, allocation for sign +, free for
// sign -
required int64 increase_bytes = 6;
// place
required string place = 7;
// current total allocated memory
required uint64 current_allocated = 8;
// current total reserved memory
required uint64 current_reserved = 9;
}
message OperatorSupplementEventProto {
// timestamp of the record
required uint64 timestamp_ns = 1;
// op type name
required string op_type = 2;
// process id of the record
required uint64 process_id = 3;
// thread id of the record
required uint64 thread_id = 4;
// input shapes
message input_shape_proto {
repeated string key = 1;
message shape_vector {
message shape { repeated uint64 size = 1; }
repeated shape shapes = 1;
}
repeated shape_vector shape_vecs = 2;
}
required input_shape_proto input_shapes = 5;
// dtypes
message dtype_proto {
repeated string key = 1;
message dtype_vector { repeated string dtype = 1; }
repeated dtype_vector dtype_vecs = 2;
}
required dtype_proto dtypes = 6;
// call stack
required string callstack = 7;
}
message CudaRuntimeTraceEventProto {
// record name
required string name = 1;
......@@ -166,6 +227,12 @@ message DeviceTraceEventProto {
}
}
message OperatorSupplementEventNodeProto {
required OperatorSupplementEventProto op_supplement_event = 1;
}
message MemTraceEventNodeProto { required MemTraceEventProto mem_event = 1; }
message DeviceTraceEventNodeProto {
required DeviceTraceEventProto device_event = 1;
}
......@@ -180,6 +247,9 @@ message HostTraceEventNodeProto {
required int64 parentid = 2;
required HostTraceEventProto host_trace_event = 3;
repeated CudaRuntimeTraceEventNodeProto runtime_nodes = 4;
// below is added in version 1.0.1
repeated MemTraceEventNodeProto mem_nodes = 5;
repeated OperatorSupplementEventNodeProto op_supplement_nodes = 6;
}
message ThreadNodeTreeProto {
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace platform {
static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.pb";
static const char* version = "1.0.0";
static const char* version = "1.0.1";
static uint32_t span_indx = 0;
static std::string DefaultFileName() {
......@@ -106,10 +106,33 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
(*devicenode)->LogMe(this); // fill detail information
}
}
for (auto memnode = (*hostnode)->GetMemTraceEventNodes().begin();
memnode != (*hostnode)->GetMemTraceEventNodes().end(); ++memnode) {
MemTraceEventNodeProto* mem_node_proto =
current_host_trace_event_node_proto_->add_mem_nodes();
current_mem_trace_event_node_proto_ = mem_node_proto;
(*memnode)->LogMe(this);
}
}
}
}
void SerializationLogger::LogMemTraceEventNode(
const MemTraceEventNode& mem_node) {
MemTraceEventProto* mem_trace_event = new MemTraceEventProto();
mem_trace_event->set_timestamp_ns(mem_node.TimeStampNs());
mem_trace_event->set_type(
static_cast<TracerMemEventTypeProto>(mem_node.Type()));
mem_trace_event->set_addr(mem_node.Addr());
mem_trace_event->set_process_id(mem_node.ProcessId());
mem_trace_event->set_thread_id(mem_node.ThreadId());
mem_trace_event->set_increase_bytes(mem_node.IncreaseBytes());
mem_trace_event->set_place(mem_node.Place());
mem_trace_event->set_current_allocated(mem_node.CurrentAllocated());
mem_trace_event->set_current_reserved(mem_node.CurrentReserved());
current_mem_trace_event_node_proto_->set_allocated_mem_event(mem_trace_event);
}
void SerializationLogger::LogHostTraceEventNode(
const HostTraceEventNode& host_node) {
HostTraceEventProto* host_trace_event = new HostTraceEventProto();
......@@ -122,6 +145,59 @@ void SerializationLogger::LogHostTraceEventNode(
host_trace_event->set_thread_id(host_node.ThreadId());
current_host_trace_event_node_proto_->set_allocated_host_trace_event(
host_trace_event);
OperatorSupplementEventNode* op_supplement_event_node =
host_node.GetOperatorSupplementEventNode();
if (op_supplement_event_node != nullptr) {
current_op_supplement_event_node_proto_ =
current_host_trace_event_node_proto_->add_op_supplement_nodes();
OperatorSupplementEventProto* op_supplement_event_proto =
new OperatorSupplementEventProto();
op_supplement_event_proto->set_op_type(op_supplement_event_node->Name());
op_supplement_event_proto->set_timestamp_ns(
op_supplement_event_node->TimeStampNs());
op_supplement_event_proto->set_process_id(
op_supplement_event_node->ProcessId());
op_supplement_event_proto->set_thread_id(
op_supplement_event_node->ThreadId());
op_supplement_event_proto->set_callstack(
op_supplement_event_node->CallStack());
OperatorSupplementEventProto::input_shape_proto* input_shape_proto =
op_supplement_event_proto->mutable_input_shapes();
for (auto it = op_supplement_event_node->InputShapes().begin();
it != op_supplement_event_node->InputShapes().end(); it++) {
input_shape_proto->add_key(it->first);
OperatorSupplementEventProto::input_shape_proto::shape_vector*
shape_vectors_proto = input_shape_proto->add_shape_vecs();
auto shape_vectors = it->second;
for (auto shape_vecs_it = shape_vectors.begin();
shape_vecs_it != shape_vectors.end(); shape_vecs_it++) {
auto shape_vector = *shape_vecs_it;
OperatorSupplementEventProto::input_shape_proto::shape_vector::shape*
shape_proto = shape_vectors_proto->add_shapes();
for (auto shape_it = shape_vector.begin();
shape_it != shape_vector.end(); shape_it++) {
shape_proto->add_size(*shape_it);
}
}
}
OperatorSupplementEventProto::dtype_proto* dtype_proto =
op_supplement_event_proto->mutable_dtypes();
for (auto it = op_supplement_event_node->Dtypes().begin();
it != op_supplement_event_node->Dtypes().end(); it++) {
dtype_proto->add_key(it->first);
OperatorSupplementEventProto::dtype_proto::dtype_vector*
dtype_vector_proto = dtype_proto->add_dtype_vecs();
auto dtype_vector = it->second;
for (auto dtype_it = dtype_vector.begin(); dtype_it != dtype_vector.end();
dtype_it++) {
dtype_vector_proto->add_dtype(*dtype_it);
}
}
current_op_supplement_event_node_proto_->set_allocated_op_supplement_event(
op_supplement_event_proto);
}
}
void SerializationLogger::LogRuntimeTraceEventNode(
......
......@@ -34,6 +34,7 @@ class SerializationLogger : public BaseLogger {
void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override;
void LogNodeTrees(const NodeTrees&) override;
void LogMetaInfo(const std::unordered_map<std::string, std::string>);
void LogMemTraceEventNode(const MemTraceEventNode&) override;
private:
void OpenFile();
......@@ -48,6 +49,8 @@ class SerializationLogger : public BaseLogger {
HostTraceEventNodeProto* current_host_trace_event_node_proto_;
CudaRuntimeTraceEventNodeProto* current_runtime_trace_event_node_proto_;
DeviceTraceEventNodeProto* current_device_trace_event_node_proto_;
MemTraceEventNodeProto* current_mem_trace_event_node_proto_;
OperatorSupplementEventNodeProto* current_op_supplement_event_node_proto_;
};
} // namespace platform
......
......@@ -34,6 +34,7 @@ using paddle::platform::ProfilerResult;
using paddle::platform::RuntimeTraceEvent;
using paddle::platform::SerializationLogger;
using paddle::platform::TracerEventType;
using paddle::platform::TracerMemEventType;
TEST(SerializationLoggerTest, dump_case0) {
std::list<HostTraceEvent> host_events;
......@@ -50,6 +51,19 @@ TEST(SerializationLoggerTest, dump_case0) {
std::string("op2"), TracerEventType::Operator, 21000, 30000, 10, 10));
host_events.push_back(HostTraceEvent(
std::string("op3"), TracerEventType::Operator, 31000, 40000, 10, 11));
mem_events.push_back(MemTraceEvent(11500, 0x1000,
TracerMemEventType::Allocate, 10, 10, 50,
"GPU:0", 50, 50));
mem_events.push_back(MemTraceEvent(11900, 0x1000, TracerMemEventType::Free,
10, 10, -50, "GPU:0", 0, 50));
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
input_shapes[std::string("X")].push_back(std::vector<int64_t>{1, 2, 3});
input_shapes[std::string("X")].push_back(std::vector<int64_t>{4, 5, 6, 7});
dtypes[std::string("X")].push_back(std::string("int8"));
dtypes[std::string("X")].push_back(std::string("float32"));
op_supplement_events.push_back(OperatorSupplementEvent(
11600, "op1", input_shapes, dtypes, "op1()", 10, 10));
runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000,
17000, 10, 10, 1, 0));
runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000,
......@@ -91,6 +105,8 @@ TEST(SerializationLoggerTest, dump_case0) {
if ((*it)->Name() == "op1") {
EXPECT_EQ((*it)->GetChildren().size(), 0u);
EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u);
EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u);
EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr);
}
}
for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) {
......@@ -100,6 +116,7 @@ TEST(SerializationLoggerTest, dump_case0) {
}
}
tree.LogMe(&logger);
logger.LogMetaInfo(std::unordered_map<std::string, std::string>());
}
TEST(SerializationLoggerTest, dump_case1) {
......@@ -154,6 +171,7 @@ TEST(SerializationLoggerTest, dump_case1) {
}
}
tree.LogMe(&logger);
logger.LogMetaInfo(std::unordered_map<std::string, std::string>());
}
TEST(DeserializationReaderTest, restore_case0) {
......@@ -173,6 +191,8 @@ TEST(DeserializationReaderTest, restore_case0) {
if ((*it)->Name() == "op1") {
EXPECT_EQ((*it)->GetChildren().size(), 0u);
EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u);
EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u);
EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr);
}
}
for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) {
......
......@@ -32,6 +32,9 @@ HostPythonNode::~HostPythonNode() {
for (auto it = device_node_ptrs.begin(); it != device_node_ptrs.end(); ++it) {
delete *it;
}
for (auto it = mem_node_ptrs.begin(); it != mem_node_ptrs.end(); ++it) {
delete *it;
}
}
HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
......@@ -77,6 +80,29 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
runtime_python_node->device_node_ptrs.push_back(device_python_node);
}
}
// copy MemTraceEventNode
for (auto memnode = root->GetMemTraceEventNodes().begin();
memnode != root->GetMemTraceEventNodes().end(); memnode++) {
MemPythonNode* mem_python_node = new MemPythonNode();
mem_python_node->timestamp_ns = (*memnode)->TimeStampNs();
mem_python_node->addr = (*memnode)->Addr();
mem_python_node->type = (*memnode)->Type();
mem_python_node->process_id = (*memnode)->ProcessId();
mem_python_node->thread_id = (*memnode)->ThreadId();
mem_python_node->increase_bytes = (*memnode)->IncreaseBytes();
mem_python_node->place = (*memnode)->Place();
mem_python_node->current_allocated = (*memnode)->CurrentAllocated();
mem_python_node->current_reserved = (*memnode)->CurrentReserved();
host_python_node->mem_node_ptrs.push_back(mem_python_node);
}
// copy OperatorSupplementEventNode's information if exists
OperatorSupplementEventNode* op_supplement_node =
root->GetOperatorSupplementEventNode();
if (op_supplement_node != nullptr) {
host_python_node->input_shapes = op_supplement_node->InputShapes();
host_python_node->dtypes = op_supplement_node->Dtypes();
host_python_node->callstack = op_supplement_node->CallStack();
}
return host_python_node;
}
......
......@@ -43,6 +43,31 @@ struct DevicePythonNode {
uint64_t stream_id;
};
struct MemPythonNode {
MemPythonNode() = default;
~MemPythonNode() {}
// timestamp of the record
uint64_t timestamp_ns;
// memory addr of allocation or free
uint64_t addr;
// memory manipulation type
TracerMemEventType type;
// process id of the record
uint64_t process_id;
// thread id of the record
uint64_t thread_id;
// increase bytes after this manipulation, allocation for sign +, free for
// sign -
int64_t increase_bytes;
// place
std::string place;
// current total allocated memory
uint64_t current_allocated;
// current total reserved memory
uint64_t current_reserved;
};
struct HostPythonNode {
HostPythonNode() = default;
~HostPythonNode();
......@@ -58,12 +83,19 @@ struct HostPythonNode {
uint64_t process_id;
// thread id of the record
uint64_t thread_id;
// input shapes
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
// call stack
std::string callstack;
// children node
std::vector<HostPythonNode*> children_node_ptrs;
// runtime node
std::vector<HostPythonNode*> runtime_node_ptrs;
// device node
std::vector<DevicePythonNode*> device_node_ptrs;
// mem node
std::vector<MemPythonNode*> mem_node_ptrs;
};
class ProfilerResult {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册