未验证 提交 c3e1d257 编写于 作者: L liutiexing 提交者: GitHub

Update Profiler (#41638)

上级 c055b50c
文件模式从 100755 更改为 100644
......@@ -152,7 +152,7 @@ TEST(SerializationLoggerTest, dump_case1) {
TEST(DeserializationReaderTest, restore_case0) {
DeserializationReader reader("test_serialization_logger_case0.pb");
auto profiler_result = reader.Parse();
auto& tree = profiler_result->GetNodeTrees();
auto tree = profiler_result->GetNodeTrees();
std::map<uint64_t, std::vector<HostTraceEventNode*>> nodes =
tree->Traverse(true);
EXPECT_EQ(nodes[10].size(), 4u);
......@@ -179,7 +179,7 @@ TEST(DeserializationReaderTest, restore_case0) {
TEST(DeserializationReaderTest, restore_case1) {
DeserializationReader reader("test_serialization_logger_case1.pb");
auto profiler_result = reader.Parse();
auto& tree = profiler_result->GetNodeTrees();
auto tree = profiler_result->GetNodeTrees();
std::map<uint64_t, std::vector<HostTraceEventNode*>> nodes =
tree->Traverse(true);
EXPECT_EQ(nodes[10].size(), 1u);
......
......@@ -103,7 +103,7 @@ class CudaRuntimeTraceEventNode {
device_node_ptrs_.push_back(node);
}
void LogMe(BaseLogger* logger) { logger->LogRuntimeTraceEventNode(*this); }
std::vector<DeviceTraceEventNode*>& GetDeviceTraceEventNodes() {
const std::vector<DeviceTraceEventNode*>& GetDeviceTraceEventNodes() const {
return device_node_ptrs_;
}
......@@ -139,8 +139,11 @@ class HostTraceEventNode {
void AddCudaRuntimeNode(CudaRuntimeTraceEventNode* node) {
runtime_node_ptrs_.push_back(node);
}
std::vector<HostTraceEventNode*>& GetChildren() { return children_; }
std::vector<CudaRuntimeTraceEventNode*>& GetRuntimeTraceEventNodes() {
const std::vector<HostTraceEventNode*>& GetChildren() const {
return children_;
}
const std::vector<CudaRuntimeTraceEventNode*>& GetRuntimeTraceEventNodes()
const {
return runtime_node_ptrs_;
}
void LogMe(BaseLogger* logger) { logger->LogHostTraceEventNode(*this); }
......@@ -188,7 +191,7 @@ class NodeTrees {
void HandleTrees(std::function<void(HostTraceEventNode*)>,
std::function<void(CudaRuntimeTraceEventNode*)>,
std::function<void(DeviceTraceEventNode*)>);
std::map<uint64_t, HostTraceEventNode*> GetNodeTrees() {
const std::map<uint64_t, HostTraceEventNode*>& GetNodeTrees() const {
return thread_event_trees_map_;
}
std::map<uint64_t, std::vector<HostTraceEventNode*>> Traverse(bool bfs) const;
......
......@@ -81,7 +81,7 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) {
ProfilerResult::ProfilerResult(std::unique_ptr<NodeTrees> tree,
const ExtraInfo& extra_info)
: tree_(std::move(tree)), extra_info_(extra_info) {
: tree_(tree.release()), extra_info_(extra_info) {
if (tree_ != nullptr) {
std::map<uint64_t, HostTraceEventNode*> nodetrees = tree_->GetNodeTrees();
for (auto it = nodetrees.begin(); it != nodetrees.end(); ++it) {
......
......@@ -82,11 +82,11 @@ class ProfilerResult {
void Save(const std::string& file_name,
const std::string format = std::string("json"));
std::unique_ptr<NodeTrees>& GetNodeTrees() { return tree_; }
std::shared_ptr<NodeTrees> GetNodeTrees() { return tree_; }
private:
std::map<uint64_t, HostPythonNode*> thread_event_trees_map_;
std::unique_ptr<NodeTrees> tree_;
std::shared_ptr<NodeTrees> tree_;
ExtraInfo extra_info_;
HostPythonNode* CopyTree(HostTraceEventNode* root);
};
......
......@@ -46,7 +46,7 @@ TEST(ProfilerTest, TestHostTracer) {
3);
}
auto profiler_result = profiler->Stop();
auto& nodetree = profiler_result->GetNodeTrees();
auto nodetree = profiler_result->GetNodeTrees();
std::set<std::string> host_events;
for (const auto pair : nodetree->Traverse(true)) {
for (const auto evt : pair.second) {
......@@ -79,7 +79,7 @@ TEST(ProfilerTest, TestCudaTracer) {
hipStreamSynchronize(stream);
#endif
auto profiler_result = profiler->Stop();
auto& nodetree = profiler_result->GetNodeTrees();
auto nodetree = profiler_result->GetNodeTrees();
std::vector<std::string> runtime_events;
for (const auto pair : nodetree->Traverse(true)) {
for (const auto host_node : pair.second) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册