From cbe7466ff733e917b84e65a423ced310d56ac20e Mon Sep 17 00:00:00 2001
From: liutiexing <74819124+liutiexing@users.noreply.github.com>
Date: Thu, 14 Apr 2022 19:01:32 +0800
Subject: [PATCH] executor perf statistics (#41648)

* executor perf statistics

* fix ut

* fix ut

* fix ut

* add ut

* add ut
---
 .../framework/new_executor/CMakeLists.txt     |   2 +
 .../new_executor/executor_statistics.cc       | 627 ++++++++++++++++++
 .../new_executor/executor_statistics.h        |  27 +
 .../new_executor/standalone_executor.cc       |   7 +
 .../new_executor/workqueue/CMakeLists.txt     |   2 +-
 .../workqueue/nonblocking_threadpool.h        |   9 +-
 paddle/fluid/pybind/CMakeLists.txt            |   2 +-
 paddle/fluid/pybind/pybind.cc                 |   9 +-
 .../unittests/interpreter/CMakeLists.txt      |   2 +-
 .../interpreter/test_standalone_executor.py   | 105 +++
 10 files changed, 782 insertions(+), 10 deletions(-)
 create mode 100644 paddle/fluid/framework/new_executor/executor_statistics.cc
 create mode 100644 paddle/fluid/framework/new_executor/executor_statistics.h

diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt
index b7b09da5ce..6046000739 100644
--- a/paddle/fluid/framework/new_executor/CMakeLists.txt
+++ b/paddle/fluid/framework/new_executor/CMakeLists.txt
@@ -20,6 +20,8 @@ endif()
 
 cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
 
+cc_library(staticgraph_executor_statistics SRCS executor_statistics.cc DEPS enforce glog os_info)
+
 # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
 # skip win32 since wget is not installed by default on windows machine.
 if (WITH_GPU AND WITH_TESTING AND NOT WIN32 AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
diff --git a/paddle/fluid/framework/new_executor/executor_statistics.cc b/paddle/fluid/framework/new_executor/executor_statistics.cc
new file mode 100644
index 0000000000..392d6c78f9
--- /dev/null
+++ b/paddle/fluid/framework/new_executor/executor_statistics.cc
@@ -0,0 +1,627 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/new_executor/executor_statistics.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "glog/logging.h" +#include "paddle/fluid/platform/flags.h" +#include "paddle/fluid/platform/os_info.h" +#include "paddle/fluid/platform/profiler/utils.h" + +DECLARE_bool(use_stream_safe_cuda_allocator); +PADDLE_DEFINE_EXPORTED_string(static_executor_perfstat_filepath, "", + "FLAGS_static_executor_perfstat_filepath " + "enables performance statistics for the static " + "graph executor."); + +namespace paddle { +namespace framework { + +class StatisticsEngine { + public: + int Apply(const platform::NodeTrees& trees); + + void Log(const std::string& full_filename); + + private: + // type + struct EventStat { + uint64_t total_time = 0; + size_t count = 0; + uint64_t normalization_time = 0; + }; + + struct Priority { + // use a smaller number to denote higher priority + int innerthread_priority = 0; + int interthread_priority = 0; + }; + + struct StdEvent { + size_t evt_idx; + uint64_t start_ns; + uint64_t end_ns; + + StdEvent(size_t idx, uint64_t start, uint64_t end) + : evt_idx(idx), start_ns(start), end_ns(end) {} + }; + + enum class ExecutorType { EXECUTOR, PARALLEL_EXECUTOR, INTERPRETER_CORE }; + + using Filter = std::function; + + int Init(const platform::NodeTrees& trees); + + int Stat(const platform::NodeTrees& trees); + + void InitStdEvents(); + + void InitInnerthreadPriorityForStdEvents(); + + void InitInterthreadPriorityForStdEvents(); + + int InitFiltersForExecutor(); + + int InitFiltersForParallelExecutor(); + + int InitFiltersForInterpreterCore(); + + int RegisterEventFilter(const std::string& std_event, Filter filter) { + auto iter = name2idx_.find(std_event); + if (iter == name2idx_.end()) { + LOG(WARNING) << "Unsupported std_event " << std_event; + return -1; + } + auto idx = iter->second; + if (filters_[idx]) { + LOG(WARNING) << "Duplicate registration for std_event(" << std_event + << ")"; + return -1; + } + filters_[idx] = std::move(filter); + return 0; + } + + void MergeEvents(std::function merger, + std::vector* in_out_evts); + + int MergeInnerthreadEvents(std::vector>* all_evts); + + int MergeInterthreadEvents(std::vector>* all_evts); + + int StatNormalizationTime(const std::vector>& all_evts); + + bool inited_ = false; + ExecutorType executor_type_; + std::vector names_; + std::vector filters_; + std::vector priorities_; + std::vector statistics_; + std::unordered_map name2idx_; +}; + +int StatisticsEngine::Apply(const platform::NodeTrees& tree) { + return Init(tree) || Stat(tree); +} + +int StatisticsEngine::Init(const platform::NodeTrees& trees) { + if (inited_) { + LOG(WARNING) << "Duplicate initialization for StatisticsEngine"; + return -1; + } + if (platform::GetCurrentThreadName() != "MainThread") { + LOG(WARNING) << "StatisticsEngin must run on the main thread"; + return -1; + } + inited_ = true; + InitStdEvents(); + InitInnerthreadPriorityForStdEvents(); + InitInterthreadPriorityForStdEvents(); + // determine executor type + uint64_t main_tid = platform::GetCurrentThreadId().sys_tid; + for (const auto& kv : trees.GetNodeTrees()) { + if (kv.first != main_tid) { + continue; + } + std::queue q; + q.push(kv.second); + while (!q.empty()) { + auto cur_node = q.front(); + q.pop(); + const auto& name = cur_node->Name(); + if (name.find("Executor::") == 0) { + VLOG(10) << "type: Executor"; + executor_type_ = ExecutorType::EXECUTOR; + return InitFiltersForExecutor(); + } else if (name.find("ParallelExecutor::") == 
+void StatisticsEngine::InitStdEvents() {
+  name2idx_["Total"] = names_.size();
+  names_.push_back("Total");
+  name2idx_["PythonEnd"] = names_.size();
+  names_.push_back("PythonEnd");
+  name2idx_["CplusplusEnd"] = names_.size();
+  names_.push_back("CplusplusEnd");
+  name2idx_["RunOp"] = names_.size();
+  names_.push_back("RunOp");
+  name2idx_["LaunchKernel"] = names_.size();
+  names_.push_back("LaunchKernel");
+  name2idx_["OpCompute"] = names_.size();
+  names_.push_back("OpCompute");
+  name2idx_["OpInfershape"] = names_.size();
+  names_.push_back("OpInfershape");
+  name2idx_["DataTransform"] = names_.size();
+  names_.push_back("DataTransform");
+  name2idx_["GarbageCollect"] = names_.size();
+  names_.push_back("GarbageCollect");
+  name2idx_["CalcNextOp"] = names_.size();
+  names_.push_back("CalcNextOp");
+  name2idx_["AllocateDeviceMem"] = names_.size();
+  names_.push_back("AllocateDeviceMem");
+  name2idx_["FreeDeviceMem"] = names_.size();
+  names_.push_back("FreeDeviceMem");
+  name2idx_["ThreadpoolAddTask"] = names_.size();
+  names_.push_back("ThreadpoolAddTask");
+
+  size_t n = names_.size();
+  filters_.resize(n);
+  priorities_.resize(n);
+  statistics_.resize(n);
+}
+
+void StatisticsEngine::InitInnerthreadPriorityForStdEvents() {
+  int prio = 0;
+  priorities_[name2idx_["AllocateDeviceMem"]].innerthread_priority = ++prio;
+  priorities_[name2idx_["FreeDeviceMem"]].innerthread_priority = prio;
+  priorities_[name2idx_["ThreadpoolAddTask"]].innerthread_priority = prio;
+
+  priorities_[name2idx_["CalcNextOp"]].innerthread_priority = ++prio;
+  priorities_[name2idx_["GarbageCollect"]].innerthread_priority = prio;
+  priorities_[name2idx_["OpCompute"]].innerthread_priority = prio;
+  priorities_[name2idx_["OpInfershape"]].innerthread_priority = prio;
+  priorities_[name2idx_["DataTransform"]].innerthread_priority = prio;
+
+  priorities_[name2idx_["RunOp"]].innerthread_priority = ++prio;
+
+  priorities_[name2idx_["CplusplusEnd"]].innerthread_priority = ++prio;
+
+  priorities_[name2idx_["Total"]].innerthread_priority = ++prio;
+}
+
+void StatisticsEngine::InitInterthreadPriorityForStdEvents() {
+  int prio = 0;
+  priorities_[name2idx_["LaunchKernel"]].interthread_priority = ++prio;
+  priorities_[name2idx_["AllocateDeviceMem"]].interthread_priority = ++prio;
+  priorities_[name2idx_["FreeDeviceMem"]].interthread_priority = ++prio;
+  priorities_[name2idx_["ThreadpoolAddTask"]].interthread_priority = ++prio;
+  priorities_[name2idx_["CalcNextOp"]].interthread_priority = ++prio;
+  priorities_[name2idx_["GarbageCollect"]].interthread_priority = ++prio;
+  priorities_[name2idx_["OpInfershape"]].interthread_priority = ++prio;
+  priorities_[name2idx_["DataTransform"]].interthread_priority = ++prio;
+  priorities_[name2idx_["RunOp"]].interthread_priority = ++prio;
+  priorities_[name2idx_["CplusplusEnd"]].interthread_priority = ++prio;
+  priorities_[name2idx_["PythonEnd"]].interthread_priority = prio;
+}
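A short reading of these two tables, which drive the merge passes further down: a smaller value means higher priority, and events sharing a value are treated as one class when their intervals overlap. The innerthread priorities resolve nesting within a single thread's timeline, so the leaf-most work (device memory allocation/free, task enqueue) outranks the operator phases that contain it, which in turn outrank RunOp, CplusplusEnd, and Total. The interthread priorities rank events when all thread timelines are flattened into one; LaunchKernel ranks first there, so kernel-launch time survives any overlap with other threads' events.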
"StreamSafeCUDAAllocator::Allocate" + : "AutoGrowthBestFitAllocator::Allocate"; +const char* free_device_mem = FLAGS_use_stream_safe_cuda_allocator + ? "StreamSafeCUDAAllocator::Free" + : "AutoGrowthBestFitAllocator::Free"; + +int StatisticsEngine::InitFiltersForExecutor() { + return RegisterEventFilter("Total", + [](const platform::HostTraceEventNode& evt) { + return evt.Name().find("ProfileStep") == 0; + }) || + RegisterEventFilter("CplusplusEnd", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == + "Executor::RunPartialPreparedContext"; + }) || + RegisterEventFilter("RunOp", + [](const platform::HostTraceEventNode& evt) { + return evt.Type() == + platform::TracerEventType::Operator; + }) || + RegisterEventFilter( + "OpCompute", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "compute" && + evt.Type() == platform::TracerEventType::OperatorInner; + }) || + RegisterEventFilter( + "OpInfershape", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "infer_shape" && + evt.Type() == platform::TracerEventType::OperatorInner; + }) || + RegisterEventFilter("GarbageCollect", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "CheckGC"; + }) || + RegisterEventFilter("AllocateDeviceMem", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == alloc_device_mem; + }) || + RegisterEventFilter("FreeDeviceMem", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == free_device_mem; + }) || + RegisterEventFilter( + "DataTransform", [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "prepare_data" && + evt.Type() == platform::TracerEventType::OperatorInner; + }); +} + +int StatisticsEngine::InitFiltersForParallelExecutor() { + return RegisterEventFilter("Total", + [](const platform::HostTraceEventNode& evt) { + return evt.Name().find("ProfileStep") == 0; + }) || + RegisterEventFilter("CplusplusEnd", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "ParallelExecutor::Run"; + }) || + RegisterEventFilter("RunOp", + [](const platform::HostTraceEventNode& evt) { + return evt.Type() == + platform::TracerEventType::Operator; + }) || + RegisterEventFilter( + "OpCompute", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "compute" && + evt.Type() == platform::TracerEventType::OperatorInner; + }) || + RegisterEventFilter( + "OpInfershape", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "infer_shape" && + evt.Type() == platform::TracerEventType::OperatorInner; + }) || + RegisterEventFilter("GarbageCollect", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "eager_deletion" || + evt.Name() == "CheckGC"; + }) || + RegisterEventFilter("AllocateDeviceMem", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == alloc_device_mem; + }) || + RegisterEventFilter("FreeDeviceMem", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == free_device_mem; + }) || + RegisterEventFilter( + "DataTransform", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "prepare_data" && + evt.Type() == platform::TracerEventType::OperatorInner; + }) || + RegisterEventFilter("ThreadpoolAddTask", + [](const platform::HostTraceEventNode& evt) { + return evt.Name() == "WorkQueue::AddTask"; + }); +} + +int StatisticsEngine::InitFiltersForInterpreterCore() { + return RegisterEventFilter("Total", + [](const platform::HostTraceEventNode& evt) { + return evt.Name().find("ProfileStep") == 0; 
+
+int StatisticsEngine::InitFiltersForInterpreterCore() {
+  return RegisterEventFilter("Total",
+                             [](const platform::HostTraceEventNode& evt) {
+                               return evt.Name().find("ProfileStep") == 0;
+                             }) ||
+         RegisterEventFilter("CplusplusEnd",
+                             [](const platform::HostTraceEventNode& evt) {
+                               return evt.Name() == "StandaloneExecutor::run";
+                             }) ||
+         RegisterEventFilter("RunOp",
+                             [](const platform::HostTraceEventNode& evt) {
+                               return evt.Type() ==
+                                      platform::TracerEventType::Operator;
+                             }) ||
+         RegisterEventFilter(
+             "OpCompute",
+             [](const platform::HostTraceEventNode& evt) {
+               return evt.Name() == "compute" &&
+                      evt.Type() == platform::TracerEventType::OperatorInner;
+             }) ||
+         RegisterEventFilter(
+             "OpInfershape",
+             [](const platform::HostTraceEventNode& evt) {
+               return evt.Name() == "infer_shape" &&
+                      evt.Type() == platform::TracerEventType::OperatorInner;
+             }) ||
+         RegisterEventFilter("GarbageCollect",
+                             [](const platform::HostTraceEventNode& evt) {
+                               return evt.Name() == "CheckGC" ||
+                                      evt.Name() == "RecordStreamForGC";
+                             }) ||
+         RegisterEventFilter("AllocateDeviceMem",
+                             [](const platform::HostTraceEventNode& evt) {
+                               return evt.Name() == alloc_device_mem;
+                             }) ||
+         RegisterEventFilter("FreeDeviceMem",
+                             [](const platform::HostTraceEventNode& evt) {
+                               return evt.Name() == free_device_mem;
+                             }) ||
+         RegisterEventFilter("CalcNextOp",
+                             [](const platform::HostTraceEventNode& evt) {
+                               return evt.Name() == "RunNextInstructions";
+                             }) ||
+         RegisterEventFilter("ThreadpoolAddTask",
+                             [](const platform::HostTraceEventNode& evt) {
+                               return evt.Name() == "WorkQueue::AddTask";
+                             });
+}
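The three filter sets differ only where the executors differ: the step boundary is Executor::RunPartialPreparedContext, ParallelExecutor::Run, or StandaloneExecutor::run; garbage collection is eager_deletion/CheckGC for ParallelExecutor and CheckGC/RecordStreamForGC for the interpreter core; only the interpreter core has a CalcNextOp event (RunNextInstructions); and only the classic Executor lacks ThreadpoolAddTask, which is why Stat() below skips the FLAGS_host_trace_level sanity check for that executor type.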
+
+int StatisticsEngine::Stat(const platform::NodeTrees& trees) {
+  // Convert the trace trees to StdEvents.
+  std::vector<std::vector<StdEvent>> all_evts;
+  for (const auto& tree : trees.GetNodeTrees()) {
+    std::vector<StdEvent> thr_evts;
+    std::queue<const platform::HostTraceEventNode*> q;
+    q.push(tree.second);
+    std::unordered_set<const platform::HostTraceEventNode*> removed;
+    while (!q.empty()) {
+      auto cur_node = q.front();
+      q.pop();
+      for (const auto& child : cur_node->GetChildren()) {
+        // Remove duplicate operator records.
+        // See InterpreterCore::RunInstruction for details.
+        if (child->Type() == platform::TracerEventType::Operator &&
+            cur_node->Name() == "compute") {
+          removed.insert(child);
+        }
+        q.push(child);
+      }
+      if (removed.count(cur_node) > 0) {
+        VLOG(10) << "Remove duplicate operator record: " << cur_node->Name();
+        continue;
+      }
+      for (size_t idx = 0; idx < filters_.size(); ++idx) {
+        if (!filters_[idx]) {
+          continue;
+        }
+        if (filters_[idx](*cur_node)) {
+          thr_evts.emplace_back(idx, cur_node->StartNs(), cur_node->EndNs());
+          break;
+        }
+      }
+    }
+    if (thr_evts.size() == 0) {
+      continue;
+    }
+    std::sort(thr_evts.begin(), thr_evts.end(),
+              [](const StdEvent& e1, const StdEvent& e2) {
+                return e1.start_ns < e2.start_ns;
+              });
+    all_evts.push_back(std::move(thr_evts));
+  }
+  if (all_evts.size() == 0) {
+    LOG(WARNING) << "No profiler events";
+    return -1;
+  }
+
+  // Accumulate total_time and count per standard event.
+  for (const auto& thr_evts : all_evts) {
+    for (const auto& evt : thr_evts) {
+      auto& evt_stat = statistics_[evt.evt_idx];
+      evt_stat.total_time += evt.end_ns - evt.start_ns;
+      evt_stat.count += 1;
+    }
+  }
+  auto& python_end = statistics_[name2idx_["PythonEnd"]];
+  const auto& total = statistics_[name2idx_["Total"]];
+  const auto& cplusplus_end = statistics_[name2idx_["CplusplusEnd"]];
+  python_end.total_time = total.total_time - cplusplus_end.total_time;
+  python_end.count = cplusplus_end.count + 1;
+
+  auto& launch_kernel = statistics_[name2idx_["LaunchKernel"]];
+  const auto& op_compute = statistics_[name2idx_["OpCompute"]];
+  const auto& allocate = statistics_[name2idx_["AllocateDeviceMem"]];
+  launch_kernel.total_time = op_compute.total_time - allocate.total_time;
+  launch_kernel.count = op_compute.count;
+
+  if (executor_type_ != ExecutorType::EXECUTOR &&
+      statistics_[name2idx_["ThreadpoolAddTask"]].count == 0) {
+    LOG(WARNING) << "Check your env variable FLAGS_host_trace_level, make sure "
+                    "FLAGS_host_trace_level >= 10.";
+    return -1;
+  }
+
+  // Compute normalization_time.
+  return MergeInnerthreadEvents(&all_evts) ||
+         MergeInterthreadEvents(&all_evts) || StatNormalizationTime(all_evts);
+}
+
+void StatisticsEngine::MergeEvents(std::function<size_t(size_t, size_t)> merger,
+                                   std::vector<StdEvent>* in_out_evts) {
+  auto evts = *in_out_evts;
+  std::sort(evts.begin(), evts.end(),
+            [](const StdEvent& e1, const StdEvent& e2) {
+              return e1.start_ns < e2.start_ns;
+            });
+
+  std::list<StdEvent> merged;
+  auto iter = merged.begin();
+  for (size_t i = 0; i < evts.size();) {
+    if (iter == merged.end()) {
+      iter = merged.insert(iter, evts[i]);
+      ++i;
+    } else if (iter->end_ns <= evts[i].start_ns) {
+      // No overlap: advance within the merged list.
+      ++iter;
+    } else if (iter->evt_idx == evts[i].evt_idx) {
+      // Same event type: extend the current interval.
+      iter->end_ns = std::max(iter->end_ns, evts[i].end_ns);
+      ++i;
+    } else {
+      // Overlapping intervals of different types: the merger picks the
+      // winner; the loser is clipped or split around the winner.
+      auto merged_type = merger(iter->evt_idx, evts[i].evt_idx);
+      if (merged_type == iter->evt_idx) {
+        if (evts[i].end_ns > iter->end_ns) {
+          evts[i].start_ns = iter->end_ns;
+          ++iter;
+        } else {
+          ++i;
+        }
+      } else {
+        StdEvent back = *iter;
+        if (back.start_ns != evts[i].start_ns) {
+          merged.insert(iter, {back.evt_idx, back.start_ns, evts[i].start_ns});
+        }
+        *iter = evts[i];
+        if (back.end_ns > evts[i].end_ns) {
+          auto pos = iter;
+          merged.insert(++pos, {back.evt_idx, evts[i].end_ns, back.end_ns});
+        }
+        ++i;
+      }
+    }
+  }
+  in_out_evts->assign(merged.begin(), merged.end());
+}
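Concretely, suppose one thread recorded OpCompute over [0, 10) ns and AllocateDeviceMem over [3, 5) ns. AllocateDeviceMem carries the smaller (stronger) innerthread_priority, so MergeEvents splits the timeline into OpCompute [0, 3), AllocateDeviceMem [3, 5), OpCompute [5, 10). After merging, every nanosecond is attributed to exactly one standard event, which is what lets the normalization sum below match Total exactly.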
+
+int StatisticsEngine::MergeInnerthreadEvents(
+    std::vector<std::vector<StdEvent>>* all_evts) {
+  auto merger = [&priorities = priorities_](size_t idx1, size_t idx2) {
+    return priorities[idx1].innerthread_priority <=
+                   priorities[idx2].innerthread_priority
+               ? idx1
+               : idx2;
+  };
+  for (auto& thr_evts : *all_evts) {
+    MergeEvents(merger, &thr_evts);
+    for (auto& evt : thr_evts) {
+      if (names_[evt.evt_idx] == "Total") {
+        evt.evt_idx = name2idx_["PythonEnd"];
+      } else if (names_[evt.evt_idx] == "OpCompute") {
+        evt.evt_idx = name2idx_["LaunchKernel"];
+      }
+    }
+  }
+  return 0;
+}
+
+int StatisticsEngine::MergeInterthreadEvents(
+    std::vector<std::vector<StdEvent>>* all_evts) {
+  auto merger = [&priorities = priorities_](size_t idx1, size_t idx2) {
+    return priorities[idx1].interthread_priority <=
+                   priorities[idx2].interthread_priority
+               ? idx1
+               : idx2;
+  };
+  // K-way merge, just the simplest implementation.
+  std::vector<StdEvent> base_list;
+  base_list.swap(all_evts->at(0));
+  for (size_t i = 1; i < all_evts->size(); ++i) {
+    auto& cur_list = all_evts->at(i);
+    base_list.reserve(base_list.size() + cur_list.size());
+    base_list.insert(base_list.end(), cur_list.begin(), cur_list.end());
+    MergeEvents(merger, &base_list);
+  }
+  all_evts->resize(1);
+  (*all_evts)[0].swap(base_list);
+  return 0;
+}
+
+int StatisticsEngine::StatNormalizationTime(
+    const std::vector<std::vector<StdEvent>>& all_evts) {
+  if (all_evts.size() != 1) {
+    LOG(WARNING) << "Invalid argument";
+    return -1;
+  }
+  for (const auto& evt : all_evts[0]) {
+    statistics_[evt.evt_idx].normalization_time += evt.end_ns - evt.start_ns;
+  }
+  // verify
+  uint64_t total = statistics_[name2idx_["Total"]].total_time;
+  uint64_t normalization_sum = 0;
+  for (size_t idx = 0; idx < statistics_.size(); ++idx) {
+    normalization_sum += statistics_[idx].normalization_time;
+  }
+  if (total - normalization_sum != 0) {
+    LOG(WARNING) << "total: " << total << " is not equal to normalization_sum: "
+                 << normalization_sum;
+    return -1;
+  }
+  return 0;
+}
+
+void StatisticsEngine::Log(const std::string& filepath) {
+  std::ofstream ofs;
+  ofs.open(filepath, std::ofstream::out | std::ofstream::trunc);
+  if (!ofs) {
+    LOG(WARNING) << "Unable to open file " << filepath << " for writing data.";
+    return;
+  }
+  ofs << "[";
+  for (size_t idx = 0; idx < statistics_.size(); ++idx) {
+    const auto& evt_stat = statistics_[idx];
+    ofs << platform::string_format(std::string(R"JSON(
+  {
+    "statistical item" : "%s",
+    "total time(ns)" : %llu,
+    "total number of times" : %llu,
+    "normalization time(ns)" : %llu
+  },)JSON"),
+                                   names_[idx].c_str(), evt_stat.total_time,
+                                   evt_stat.count, evt_stat.normalization_time);
+  }
+  // Overwrite the trailing comma with the closing bracket.
+  ofs.seekp(-1, std::ios_base::end);
+  ofs << "]";
+  if (ofs) {
+    LOG(INFO) << "writing the executor performance statistics to " << filepath;
+  }
+  ofs.close();
+}
+
+void StaticGraphExecutorPerfStatistics(
+    std::shared_ptr<const platform::NodeTrees> profiling_data) {
+  if (FLAGS_static_executor_perfstat_filepath.size() == 0) {
+    VLOG(5) << "StaticGraphExecutorPerfStatistics is disabled";
+    return;
+  }
+  StatisticsEngine engine;
+  if (engine.Apply(*profiling_data) == 0) {
+    engine.Log(FLAGS_static_executor_perfstat_filepath);
+  }
+}
+
+}  // namespace framework
+}  // namespace paddle
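Taken together with the flag defined at the top of the file: setting FLAGS_static_executor_perfstat_filepath to a writable path (and FLAGS_host_trace_level >= 10, so WorkQueue events are traced) makes Log() emit one JSON record per standard event. The array is shaped roughly as follows; field names come from the format string above, values are invented for illustration:

    [
      {
        "statistical item" : "Total",
        "total time(ns)" : 52310042,
        "total number of times" : 1,
        "normalization time(ns)" : 52310042
      },
      {
        "statistical item" : "PythonEnd",
        ...
      }
    ]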
diff --git a/paddle/fluid/framework/new_executor/executor_statistics.h b/paddle/fluid/framework/new_executor/executor_statistics.h
new file mode 100644
index 0000000000..530e945596
--- /dev/null
+++ b/paddle/fluid/framework/new_executor/executor_statistics.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "paddle/fluid/platform/profiler/event_node.h"
+
+namespace paddle {
+namespace framework {
+
+void StaticGraphExecutorPerfStatistics(
+    std::shared_ptr<const platform::NodeTrees> profiling_data);
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc
index a225023147..4d4f7c74cd 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/new_executor/standalone_executor.h"
 #include "paddle/fluid/framework/new_executor/interpretercore_util.h"
+#include "paddle/fluid/platform/profiler/event_tracing.h"
 
 namespace paddle {
 namespace framework {
@@ -59,6 +60,9 @@ paddle::framework::FetchList StandaloneExecutor::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<framework::LoDTensor>& feed_tensors,
     const std::vector<std::string>& fetch_names) {
+  platform::RecordEvent record_event("StandaloneExecutor::run",
+                                     platform::TracerEventType::UserDefined, 1);
+
   auto core = GetInterpreterCore(feed_names, fetch_names, true);
 
   return core->Run(feed_names, feed_tensors);
@@ -67,6 +71,9 @@ paddle::framework::FetchList StandaloneExecutor::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<std::string>& fetch_names) {
+  platform::RecordEvent record_event("StandaloneExecutor::run",
+                                     platform::TracerEventType::UserDefined, 1);
+
   auto core = GetInterpreterCore(feed_names, fetch_names, false);
   VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
   return core->Run(feed_names);
diff --git a/paddle/fluid/framework/new_executor/workqueue/CMakeLists.txt b/paddle/fluid/framework/new_executor/workqueue/CMakeLists.txt
index f47a274aaa..2690b29e01 100644
--- a/paddle/fluid/framework/new_executor/workqueue/CMakeLists.txt
+++ b/paddle/fluid/framework/new_executor/workqueue/CMakeLists.txt
@@ -1,3 +1,3 @@
 cc_library(workqueue_utils SRCS workqueue_utils.cc events_waiter.cc DEPS enforce glog)
-cc_library(workqueue SRCS workqueue.cc DEPS workqueue_utils enforce glog)
+cc_library(workqueue SRCS workqueue.cc DEPS workqueue_utils enforce glog os_info)
 cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
diff --git a/paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h b/paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h
index 44953fa192..a599bc41f6 100644
--- a/paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h
+++ b/paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h
@@ -129,6 +129,7 @@ class ThreadPoolTempl {
     // this. We expect that such scenario is prevented by program, that is,
     // this is kept alive while any threads can potentially be in Schedule.
     if (!t.f) {
+      // Allow 'false positive' which makes a redundant notification.
      if (num_tasks > num_threads_ - blocked_) {
         VLOG(6) << "Add task, Notify";
         ec_.Notify(false);
@@ -379,9 +380,8 @@ class ThreadPoolTempl {
       return false;
     }
 
-    // Number of blocked threads is used as termination condition.
-    // If we are shutting down and all worker threads blocked without work,
-    // that's we are done.
+    // Number of blocked threads is used as notification condition.
+    // We must increase the counter before the emptiness check.
     blocked_++;
 
     // Now do a reliable emptiness check.
@@ -393,6 +393,9 @@ class ThreadPoolTempl {
       return true;
     }
 
+    // Number of blocked threads is used as termination condition.
+    // If we are shutting down and all worker threads blocked without work,
+    // that means we are done.
     if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
       ec_.CancelWait();
       // Almost done, but need to re-check queues.
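The reshuffled comments document an ordering invariant worth spelling out: a waiting worker increments blocked_ *before* its reliable emptiness re-check, so the producer-side test `num_tasks > num_threads_ - blocked_` can at worst over-count blocked workers and fire a redundant ec_.Notify(false) (the "false positive" noted above); it can never skip a wakeup for a task that survives the emptiness check. The shutdown test `done_ && blocked_ == static_cast<unsigned>(num_threads_)` relies on the same increment-first ordering.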
.def("stop", [](paddle::platform::Profiler *profiler) { platform::DisableHostEventRecorder(); - return profiler->Stop(); + auto result = profiler->Stop(); + framework::StaticGraphExecutorPerfStatistics( + result->GetNodeTrees()); + return result; }, py::return_value_policy::automatic_reference); diff --git a/python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt b/python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt index c1a2c36d8a..09cc6ed5b5 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt @@ -2,7 +2,7 @@ file(GLOB TEST_INTERP_CASES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") foreach(target ${TEST_INTERP_CASES}) - py_test_modules(${target} MODULES ${target} ENVS FLAGS_allocator_strategy=auto_growth FLAGS_use_stream_safe_cuda_allocator=true FLAGS_fast_eager_deletion_mode=false FLAGS_eager_delete_tensor_gb=0) + py_test_modules(${target} MODULES ${target} ENVS FLAGS_host_trace_level=10 FLAGS_static_executor_perfstat_filepath=./perfstat FLAGS_allocator_strategy=auto_growth FLAGS_use_stream_safe_cuda_allocator=true FLAGS_fast_eager_deletion_mode=false FLAGS_eager_delete_tensor_gb=0) py_test_modules(${target}_non_eager_deletion MODULES ${target} ENVS FLAGS_allocator_strategy=auto_growth FLAGS_use_stream_safe_cuda_allocator=true FLAGS_fast_eager_deletion_mode=false FLAGS_eager_delete_tensor_gb=0.000001) py_test_modules(${target}_fast_gc MODULES ${target} ENVS FLAGS_allocator_strategy=auto_growth FLAGS_use_stream_safe_cuda_allocator=true FLAGS_fast_eager_deletion_mode=true FLAGS_eager_delete_tensor_gb=0) py_test_modules(${target}_fast_gc_non_eager_deletion MODULES ${target} ENVS FLAGS_allocator_strategy=auto_growth FLAGS_use_stream_safe_cuda_allocator=true FLAGS_fast_eager_deletion_mode=true FLAGS_eager_delete_tensor_gb=0.000001) diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py index c07d4cc15b..a4dad5f53f 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py @@ -15,10 +15,13 @@ import os os.environ['FLAGS_use_stream_safe_cuda_allocator'] = "true" import sys +import shutil import unittest import paddle +import json from paddle.fluid import core from paddle.fluid.core import StandaloneExecutor +from paddle.profiler import profiler import numpy as np @@ -116,6 +119,107 @@ def build_program(): return main_program, startup_program, [mean] +class ExecutorStatisticsTestCase(unittest.TestCase): + def setUp(self): + self.iter_n = 3 + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + + def test_standalone_executor_statistics(self): + if os.getenv("FLAGS_static_executor_perfstat_filepath") is None: + return + + paddle.seed(2020) + main_program, startup_program, fetch_list = build_program() + fetch_list = [x.name for x in fetch_list] + + p = core.Place() + p.set_place(self.place) + executor = StandaloneExecutor(p, startup_program.desc, + main_program.desc, core.Scope()) + + helper_profiler = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2)) + helper_profiler.start() + for i in range(self.iter_n): + executor.run({}, fetch_list) + helper_profiler.step() + helper_profiler.stop() + + perfstat_filepath = os.environ[ + 
+        perfstat_filepath = os.environ[
+            'FLAGS_static_executor_perfstat_filepath']
+        self.assertTrue(os.path.exists(perfstat_filepath))
+        with open(perfstat_filepath, 'r') as load_f:
+            stat_res = json.load(load_f)
+            self.assertTrue(len(stat_res) > 0)
+
+        os.remove(perfstat_filepath)
+        shutil.rmtree('./profiler_log')
+
+    def test_parallel_executor_statistics(self):
+        if os.getenv("FLAGS_static_executor_perfstat_filepath") is None:
+            return
+
+        paddle.seed(2020)
+        main_program, startup_program, fetch_list = build_program()
+        fetch_list = [x.name for x in fetch_list]
+
+        main_program = paddle.fluid.compiler.CompiledProgram(main_program)
+        os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0'
+        executor = paddle.static.Executor(self.place)
+        os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
+        executor.run(startup_program)
+
+        helper_profiler = profiler.Profiler(
+            targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2))
+        helper_profiler.start()
+        for i in range(self.iter_n):
+            executor.run(main_program, fetch_list=fetch_list)
+            helper_profiler.step()
+        helper_profiler.stop()
+
+        perfstat_filepath = os.environ[
+            'FLAGS_static_executor_perfstat_filepath']
+        self.assertTrue(os.path.exists(perfstat_filepath))
+        with open(perfstat_filepath, 'r') as load_f:
+            stat_res = json.load(load_f)
+            self.assertTrue(len(stat_res) > 0)
+
+        os.remove(perfstat_filepath)
+        shutil.rmtree('./profiler_log')
+
+    def test_executor_statistics(self):
+        if os.getenv("FLAGS_static_executor_perfstat_filepath") is None:
+            return
+
+        paddle.seed(2020)
+        main_program, startup_program, fetch_list = build_program()
+        fetch_list = [x.name for x in fetch_list]
+
+        os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0'
+        executor = paddle.static.Executor(self.place)
+        os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
+        executor.run(startup_program)
+
+        helper_profiler = profiler.Profiler(
+            targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2))
+        helper_profiler.start()
+        for i in range(self.iter_n):
+            executor.run(main_program, fetch_list=fetch_list)
+            helper_profiler.step()
+        helper_profiler.stop()
+
+        perfstat_filepath = os.environ[
+            'FLAGS_static_executor_perfstat_filepath']
+        self.assertTrue(os.path.exists(perfstat_filepath))
+        with open(perfstat_filepath, 'r') as load_f:
+            stat_res = json.load(load_f)
+            self.assertTrue(len(stat_res) > 0)
+
+        os.remove(perfstat_filepath)
+        shutil.rmtree('./profiler_log')
+
+
 class MultiStreamModelTestCase(unittest.TestCase):
     def setUp(self):
         self.iter_n = 2
@@ -155,6 +259,7 @@ class MultiStreamModelTestCase(unittest.TestCase):
         p.set_place(self.place)
         inter_core = StandaloneExecutor(p, startup_program.desc,
                                         main_program.desc, core.Scope())
+        outs = []
         for i in range(self.iter_n):
             outs.append(
-- 
GitLab