提交 a8e85077 作者: Liu Yiqun

Refine the profiling code for inference.

上级 b825c792
...@@ -74,6 +74,9 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { ...@@ -74,6 +74,9 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
platform::SetDeviceId(dev_id); platform::SetDeviceId(dev_id);
#endif #endif
} }
// profile
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
RunImpl(scope, place); RunImpl(scope, place);
} }
...@@ -497,9 +500,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -497,9 +500,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
// profile
platform::RecordEvent record_event(Type(), dev_ctx);
// check if op[type] has kernel registered. // check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels(); auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_); auto kernels_iter = all_op_kernels.find(type_);
......
...@@ -115,11 +115,11 @@ void TestInference(const std::string& dirname, ...@@ -115,11 +115,11 @@ void TestInference(const std::string& dirname,
#endif #endif
} }
// Enable the profiler
paddle::platform::EnableProfiler(state);
// 2. Initialize the inference_program and load parameters // 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program; std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
// Enable the profiler
paddle::platform::EnableProfiler(state);
{ {
paddle::platform::RecordEvent record_event( paddle::platform::RecordEvent record_event(
"init_program", "init_program",
...@@ -143,6 +143,10 @@ void TestInference(const std::string& dirname, ...@@ -143,6 +143,10 @@ void TestInference(const std::string& dirname,
inference_program = paddle::inference::Load(executor, *scope, dirname); inference_program = paddle::inference::Load(executor, *scope, dirname);
} }
} }
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
"load_program_profiler.txt");
paddle::platform::ResetProfiler();
// 3. Get the feed_target_names and fetch_target_names // 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names = const std::vector<std::string>& feed_target_names =
...@@ -165,6 +169,12 @@ void TestInference(const std::string& dirname, ...@@ -165,6 +169,12 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program // 6. Run the inference program
{ {
// Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
// Enable the profiler
paddle::platform::EnableProfiler(state);
// Run repeat times to profile the performance // Run repeat times to profile the performance
for (int i = 0; i < repeat; ++i) { for (int i = 0; i < repeat; ++i) {
paddle::platform::RecordEvent record_event( paddle::platform::RecordEvent record_event(
...@@ -173,12 +183,13 @@ void TestInference(const std::string& dirname, ...@@ -173,12 +183,13 @@ void TestInference(const std::string& dirname,
executor.Run(*inference_program, scope, feed_targets, fetch_targets); executor.Run(*inference_program, scope, feed_targets, fetch_targets);
} }
}
// Disable the profiler and print the timing information // Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, paddle::platform::DisableProfiler(
"profiler.txt"); paddle::platform::EventSortingKey::kDefault,
paddle::platform::ResetProfiler(); "run_inference_profiler.txt");
paddle::platform::ResetProfiler();
}
delete scope; delete scope;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册