未验证 提交 c3e1d257 作者: liutiexing 提交者: GitHub

Update Profiler (#41638)

上级 c055b50c
文件模式从 100755 更改为 100644
...@@ -152,7 +152,7 @@ TEST(SerializationLoggerTest, dump_case1) { ...@@ -152,7 +152,7 @@ TEST(SerializationLoggerTest, dump_case1) {
TEST(DeserializationReaderTest, restore_case0) { TEST(DeserializationReaderTest, restore_case0) {
DeserializationReader reader("test_serialization_logger_case0.pb"); DeserializationReader reader("test_serialization_logger_case0.pb");
auto profiler_result = reader.Parse(); auto profiler_result = reader.Parse();
auto& tree = profiler_result->GetNodeTrees(); auto tree = profiler_result->GetNodeTrees();
std::map<uint64_t, std::vector<HostTraceEventNode*>> nodes = std::map<uint64_t, std::vector<HostTraceEventNode*>> nodes =
tree->Traverse(true); tree->Traverse(true);
EXPECT_EQ(nodes[10].size(), 4u); EXPECT_EQ(nodes[10].size(), 4u);
...@@ -179,7 +179,7 @@ TEST(DeserializationReaderTest, restore_case0) { ...@@ -179,7 +179,7 @@ TEST(DeserializationReaderTest, restore_case0) {
TEST(DeserializationReaderTest, restore_case1) { TEST(DeserializationReaderTest, restore_case1) {
DeserializationReader reader("test_serialization_logger_case1.pb"); DeserializationReader reader("test_serialization_logger_case1.pb");
auto profiler_result = reader.Parse(); auto profiler_result = reader.Parse();
auto& tree = profiler_result->GetNodeTrees(); auto tree = profiler_result->GetNodeTrees();
std::map<uint64_t, std::vector<HostTraceEventNode*>> nodes = std::map<uint64_t, std::vector<HostTraceEventNode*>> nodes =
tree->Traverse(true); tree->Traverse(true);
EXPECT_EQ(nodes[10].size(), 1u); EXPECT_EQ(nodes[10].size(), 1u);
......
...@@ -103,7 +103,7 @@ class CudaRuntimeTraceEventNode { ...@@ -103,7 +103,7 @@ class CudaRuntimeTraceEventNode {
device_node_ptrs_.push_back(node); device_node_ptrs_.push_back(node);
} }
void LogMe(BaseLogger* logger) { logger->LogRuntimeTraceEventNode(*this); } void LogMe(BaseLogger* logger) { logger->LogRuntimeTraceEventNode(*this); }
std::vector<DeviceTraceEventNode*>& GetDeviceTraceEventNodes() { const std::vector<DeviceTraceEventNode*>& GetDeviceTraceEventNodes() const {
return device_node_ptrs_; return device_node_ptrs_;
} }
...@@ -139,8 +139,11 @@ class HostTraceEventNode { ...@@ -139,8 +139,11 @@ class HostTraceEventNode {
void AddCudaRuntimeNode(CudaRuntimeTraceEventNode* node) { void AddCudaRuntimeNode(CudaRuntimeTraceEventNode* node) {
runtime_node_ptrs_.push_back(node); runtime_node_ptrs_.push_back(node);
} }
std::vector<HostTraceEventNode*>& GetChildren() { return children_; } const std::vector<HostTraceEventNode*>& GetChildren() const {
std::vector<CudaRuntimeTraceEventNode*>& GetRuntimeTraceEventNodes() { return children_;
}
const std::vector<CudaRuntimeTraceEventNode*>& GetRuntimeTraceEventNodes()
const {
return runtime_node_ptrs_; return runtime_node_ptrs_;
} }
void LogMe(BaseLogger* logger) { logger->LogHostTraceEventNode(*this); } void LogMe(BaseLogger* logger) { logger->LogHostTraceEventNode(*this); }
...@@ -188,7 +191,7 @@ class NodeTrees { ...@@ -188,7 +191,7 @@ class NodeTrees {
void HandleTrees(std::function<void(HostTraceEventNode*)>, void HandleTrees(std::function<void(HostTraceEventNode*)>,
std::function<void(CudaRuntimeTraceEventNode*)>, std::function<void(CudaRuntimeTraceEventNode*)>,
std::function<void(DeviceTraceEventNode*)>); std::function<void(DeviceTraceEventNode*)>);
std::map<uint64_t, HostTraceEventNode*> GetNodeTrees() { const std::map<uint64_t, HostTraceEventNode*>& GetNodeTrees() const {
return thread_event_trees_map_; return thread_event_trees_map_;
} }
std::map<uint64_t, std::vector<HostTraceEventNode*>> Traverse(bool bfs) const; std::map<uint64_t, std::vector<HostTraceEventNode*>> Traverse(bool bfs) const;
......
...@@ -81,7 +81,7 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) { ...@@ -81,7 +81,7 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
ProfilerResult::ProfilerResult(std::unique_ptr<NodeTrees> tree, ProfilerResult::ProfilerResult(std::unique_ptr<NodeTrees> tree,
const ExtraInfo& extra_info) const ExtraInfo& extra_info)
: tree_(std::move(tree)), extra_info_(extra_info) { : tree_(tree.release()), extra_info_(extra_info) {
if (tree_ != nullptr) { if (tree_ != nullptr) {
std::map<uint64_t, HostTraceEventNode*> nodetrees = tree_->GetNodeTrees(); std::map<uint64_t, HostTraceEventNode*> nodetrees = tree_->GetNodeTrees();
for (auto it = nodetrees.begin(); it != nodetrees.end(); ++it) { for (auto it = nodetrees.begin(); it != nodetrees.end(); ++it) {
......
...@@ -82,11 +82,11 @@ class ProfilerResult { ...@@ -82,11 +82,11 @@ class ProfilerResult {
void Save(const std::string& file_name, void Save(const std::string& file_name,
const std::string format = std::string("json")); const std::string format = std::string("json"));
std::unique_ptr<NodeTrees>& GetNodeTrees() { return tree_; } std::shared_ptr<NodeTrees> GetNodeTrees() { return tree_; }
private: private:
std::map<uint64_t, HostPythonNode*> thread_event_trees_map_; std::map<uint64_t, HostPythonNode*> thread_event_trees_map_;
std::unique_ptr<NodeTrees> tree_; std::shared_ptr<NodeTrees> tree_;
ExtraInfo extra_info_; ExtraInfo extra_info_;
HostPythonNode* CopyTree(HostTraceEventNode* root); HostPythonNode* CopyTree(HostTraceEventNode* root);
}; };
......
...@@ -46,7 +46,7 @@ TEST(ProfilerTest, TestHostTracer) { ...@@ -46,7 +46,7 @@ TEST(ProfilerTest, TestHostTracer) {
3); 3);
} }
auto profiler_result = profiler->Stop(); auto profiler_result = profiler->Stop();
auto& nodetree = profiler_result->GetNodeTrees(); auto nodetree = profiler_result->GetNodeTrees();
std::set<std::string> host_events; std::set<std::string> host_events;
for (const auto pair : nodetree->Traverse(true)) { for (const auto pair : nodetree->Traverse(true)) {
for (const auto evt : pair.second) { for (const auto evt : pair.second) {
...@@ -79,7 +79,7 @@ TEST(ProfilerTest, TestCudaTracer) { ...@@ -79,7 +79,7 @@ TEST(ProfilerTest, TestCudaTracer) {
hipStreamSynchronize(stream); hipStreamSynchronize(stream);
#endif #endif
auto profiler_result = profiler->Stop(); auto profiler_result = profiler->Stop();
auto& nodetree = profiler_result->GetNodeTrees(); auto nodetree = profiler_result->GetNodeTrees();
std::vector<std::string> runtime_events; std::vector<std::string> runtime_events;
for (const auto pair : nodetree->Traverse(true)) { for (const auto pair : nodetree->Traverse(true)) {
for (const auto host_node : pair.second) { for (const auto host_node : pair.second) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册