提交 a8e85077 编写于 作者: L Liu Yiqun

Refine the profiling code for inference.

上级 b825c792
......@@ -74,6 +74,9 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
platform::SetDeviceId(dev_id);
#endif
}
// profile
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
RunImpl(scope, place);
}
......@@ -497,9 +500,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(place);
// profile
platform::RecordEvent record_event(Type(), dev_ctx);
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
......
......@@ -115,11 +115,11 @@ void TestInference(const std::string& dirname,
#endif
}
// Enable the profiler
paddle::platform::EnableProfiler(state);
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
// Enable the profiler
paddle::platform::EnableProfiler(state);
{
paddle::platform::RecordEvent record_event(
"init_program",
......@@ -143,6 +143,10 @@ void TestInference(const std::string& dirname,
inference_program = paddle::inference::Load(executor, *scope, dirname);
}
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
"load_program_profiler.txt");
paddle::platform::ResetProfiler();
// 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names =
......@@ -165,6 +169,12 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
// Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
// Enable the profiler
paddle::platform::EnableProfiler(state);
// Run repeat times to profile the performance
for (int i = 0; i < repeat; ++i) {
paddle::platform::RecordEvent record_event(
......@@ -173,12 +183,13 @@ void TestInference(const std::string& dirname,
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
}
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
"profiler.txt");
paddle::platform::ResetProfiler();
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(
paddle::platform::EventSortingKey::kDefault,
"run_inference_profiler.txt");
paddle::platform::ResetProfiler();
}
delete scope;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册