Commit c00843f4 authored by tensor-tang

enable multi-threads

Parent 400f5e7c
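This commit adds a gettimeofday-based millisecond timer to the sentiment inference test, logs the wall-clock time taken to join all inference threads, and strips the profiler bracketing from the shared TestInference helper (second file below). The following is a minimal, self-contained sketch of the resulting spawn/join/measure pattern; run_inference, num_threads, and the main driver are illustrative stand-ins, not the test's actual code:

// Sketch of the pattern this commit enables: launch N inference threads,
// join them all, and report wall-clock time via a gettimeofday-based timer.
#include <sys/time.h>
#include <cstdio>
#include <memory>
#include <thread>
#include <vector>

inline double get_current_ms() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec;
}

// Hypothetical stand-in for one thread's inference work.
void run_inference(int thread_id) { /* run the model here */ }

int main() {
  const int num_threads = 4;  // illustrative value
  std::vector<std::unique_ptr<std::thread>> infer_threads;
  for (int i = 0; i < num_threads; ++i) {
    infer_threads.emplace_back(new std::thread(run_inference, i));
  }
  // Time only the join phase, mirroring the test change below.
  auto start_ms = get_current_ms();
  for (int i = 0; i < num_threads; ++i) {
    infer_threads[i]->join();
  }
  auto stop_ms = get_current_ms();
  printf("total: %f ms\n", stop_ms - start_ms);
  return 0;
}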
@@ -25,6 +25,12 @@ DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");
 DEFINE_bool(prepare_vars, true, "Prepare variables before executor");
 DEFINE_bool(prepare_context, true, "Prepare Context before executor");
 
+inline double get_current_ms() {
+  struct timeval time;
+  gettimeofday(&time, NULL);
+  return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec;
+}
+
 TEST(inference, understand_sentiment) {
   if (FLAGS_dirname.empty()) {
     LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
@@ -102,4 +108,10 @@ TEST(inference, understand_sentiment) {
       }
     }));
   }
+  auto start_ms = get_current_ms();
+  for (int i = 0; i < num_threads; ++i) {
+    infer_threads[i]->join();
+  }
+  auto stop_ms = get_current_ms();
+  LOG(INFO) << "total: " << stop_ms - start_ms << " ms";
 }
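Note that start_ms is taken only after the spawn loop has finished, and stop_ms only after the last join, so the logged total is the wall-clock time for the slowest thread to complete, not a sum of per-thread latencies. One caveat of the gettimeofday-based helper is that it follows the wall clock, which can jump under NTP adjustment; a monotonic variant using std::chrono::steady_clock (an alternative sketch, not part of this commit) would be:

#include <chrono>

// Milliseconds since an arbitrary fixed point; only differences between
// two readings are meaningful, which is all the test needs.
inline double get_current_ms_steady() {
  using namespace std::chrono;
  return duration<double, std::milli>(
             steady_clock::now().time_since_epoch())
      .count();
}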
@@ -156,27 +156,10 @@ void TestInference(const std::string& dirname,
   auto executor = paddle::framework::Executor(place);
   auto* scope = new paddle::framework::Scope();
 
-  // Profile the performance
-  paddle::platform::ProfilerState state;
-  if (paddle::platform::is_cpu_place(place)) {
-    state = paddle::platform::ProfilerState::kCPU;
-  } else {
-#ifdef PADDLE_WITH_CUDA
-    state = paddle::platform::ProfilerState::kAll;
-    // The default device_id of paddle::platform::CUDAPlace is 0.
-    // Users can get the device_id using:
-    // int device_id = place.GetDeviceId();
-    paddle::platform::SetDeviceId(0);
-#else
-    PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
-#endif
-  }
 
   // 2. Initialize the inference_program and load parameters
   std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
 
-  // Enable the profiler
-  paddle::platform::EnableProfiler(state);
   {
     paddle::platform::RecordEvent record_event(
         "init_program",
@@ -189,10 +172,6 @@ void TestInference(const std::string& dirname,
       EnableMKLDNN(inference_program);
     }
   }
-  // Disable the profiler and print the timing information
-  paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
-                                    "load_program_profiler");
-  paddle::platform::ResetProfiler();
 
   // 3. Get the feed_target_names and fetch_target_names
   const std::vector<std::string>& feed_target_names =
@@ -233,9 +212,6 @@ void TestInference(const std::string& dirname,
                    true, CreateVars);
     }
 
-    // Enable the profiler
-    paddle::platform::EnableProfiler(state);
-
     // Run repeat times to profile the performance
     for (int i = 0; i < repeat; ++i) {
       paddle::platform::RecordEvent record_event(
@@ -252,11 +228,6 @@ void TestInference(const std::string& dirname,
                      CreateVars);
       }
     }
-
-    // Disable the profiler and print the timing information
-    paddle::platform::DisableProfiler(
-        paddle::platform::EventSortingKey::kDefault, "run_inference_profiler");
-    paddle::platform::ResetProfiler();
   }
 
   delete scope;
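For context, the profiler calls deleted from TestInference above formed enable/disable brackets around both program loading and the repeated inference runs. Reassembled from the removed lines (abbreviated; the surrounding place-selection and run logic is omitted), the pattern was roughly:

// Choose a profiler state for the current place (kCPU here; kAll on CUDA).
paddle::platform::ProfilerState state = paddle::platform::ProfilerState::kCPU;

paddle::platform::EnableProfiler(state);  // start collecting events
// ... load the program, or run the repeat loop ...
paddle::platform::DisableProfiler(
    paddle::platform::EventSortingKey::kDefault,  // sorting for the report
    "load_program_profiler");  // output label, as used in the deleted code
paddle::platform::ResetProfiler();  // clear events before the next bracket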