diff --git a/paddle/fluid/inference/tests/book/test_inference_nlp.cc b/paddle/fluid/inference/tests/book/test_inference_nlp.cc
index 4e92d6a17b0f03e96f83f20cd127170005a01b53..fba64efece8b4782dc4566b62949aea4ac74f323 100644
--- a/paddle/fluid/inference/tests/book/test_inference_nlp.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_nlp.cc
@@ -19,6 +19,10 @@ limitations under the License. */
 #include "gflags/gflags.h"
 #include "gtest/gtest.h"
 #include "paddle/fluid/inference/tests/test_helper.h"
+#ifdef PADDLE_WITH_MKLML
+#include <mkl_service.h>
+#include <omp.h>
+#endif
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
 DEFINE_int32(repeat, 100, "Running the inference program repeat times");
@@ -149,6 +153,14 @@ TEST(inference, nlp) {
     EnableMKLDNN(inference_program);
   }
 
+#ifdef PADDLE_WITH_MKLML
+  // only use 1 core per thread
+  omp_set_dynamic(0);
+  omp_set_num_threads(1);
+  mkl_set_num_threads(1);
+#endif
+
+  double start_ms = 0, stop_ms = 0;
   if (FLAGS_num_threads > 1) {
     std::vector<std::vector<paddle::framework::LoDTensor>> jobs;
     bcast_datasets(datasets, &jobs, FLAGS_num_threads);
@@ -158,9 +170,11 @@ TEST(inference, nlp) {
                               std::ref(inference_program), std::ref(jobs)));
     }
+    start_ms = get_current_ms();
     for (int i = 0; i < FLAGS_num_threads; ++i) {
       threads[i]->join();
     }
+    stop_ms = get_current_ms();
   } else {
     if (FLAGS_prepare_vars) {
@@ -185,16 +199,18 @@ TEST(inference, nlp) {
     std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
 
     // for data and run
-    auto start_ms = get_current_ms();
+    start_ms = get_current_ms();
     for (size_t i = 0; i < datasets.size(); ++i) {
       feed_targets[feed_target_names[0]] = &(datasets[i]);
       executor.RunPreparedContext(ctx.get(), scope, &feed_targets,
                                   &fetch_targets, !FLAGS_prepare_vars);
     }
-    auto stop_ms = get_current_ms();
-    LOG(INFO) << "Total infer time: " << (stop_ms - start_ms) / 1000.0 / 60
-              << " min, avg time per seq: "
-              << (stop_ms - start_ms) / datasets.size() << " ms";
+    stop_ms = get_current_ms();
   }
+
+  LOG(INFO) << "Total inference time with " << FLAGS_num_threads
+            << " threads : " << (stop_ms - start_ms) / 1000.0
+            << " sec, avg time per seq: "
+            << (stop_ms - start_ms) / datasets.size() << " ms";
   delete scope;
 }