Commit a4822ed8 authored by T tensor-tang

add thread setting

Parent 53875625
@@ -19,6 +19,10 @@ limitations under the License. */
 #include "gflags/gflags.h"
 #include "gtest/gtest.h"
 #include "paddle/fluid/inference/tests/test_helper.h"
+#ifdef PADDLE_WITH_MKLML
+#include <mkl_service.h>
+#include <omp.h>
+#endif
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
 DEFINE_int32(repeat, 100, "Running the inference program repeat times");
@@ -149,6 +153,14 @@ TEST(inference, nlp) {
     EnableMKLDNN(inference_program);
   }
 
+#ifdef PADDLE_WITH_MKLML
+  // only use 1 core per thread
+  omp_set_dynamic(0);
+  omp_set_num_threads(1);
+  mkl_set_num_threads(1);
+#endif
+
+  double start_ms = 0, stop_ms = 0;
   if (FLAGS_num_threads > 1) {
     std::vector<std::vector<const paddle::framework::LoDTensor*>> jobs;
     bcast_datasets(datasets, &jobs, FLAGS_num_threads);
@@ -158,9 +170,11 @@ TEST(inference, nlp) {
                                        std::ref(inference_program),
                                        std::ref(jobs)));
     }
+    start_ms = get_current_ms();
     for (int i = 0; i < FLAGS_num_threads; ++i) {
       threads[i]->join();
     }
+    stop_ms = get_current_ms();
   } else {
     if (FLAGS_prepare_vars) {
@@ -185,16 +199,18 @@ TEST(inference, nlp) {
     std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
 
     // for data and run
-    auto start_ms = get_current_ms();
+    start_ms = get_current_ms();
     for (size_t i = 0; i < datasets.size(); ++i) {
       feed_targets[feed_target_names[0]] = &(datasets[i]);
       executor.RunPreparedContext(ctx.get(), scope, &feed_targets,
                                   &fetch_targets, !FLAGS_prepare_vars);
     }
-    auto stop_ms = get_current_ms();
-    LOG(INFO) << "Total infer time: " << (stop_ms - start_ms) / 1000.0 / 60
-              << " min, avg time per seq: "
-              << (stop_ms - start_ms) / datasets.size() << " ms";
+    stop_ms = get_current_ms();
   }
+  LOG(INFO) << "Total inference time with " << FLAGS_num_threads
+            << " threads : " << (stop_ms - start_ms) / 1000.0
+            << " sec, avg time per seq: "
+            << (stop_ms - start_ms) / datasets.size() << " ms";
   delete scope;
 }
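
Note (not part of the commit): the MKLML block caps OpenMP and MKL at one thread each, so the FLAGS_num_threads worker threads each stay on one core instead of contending over a shared BLAS thread pool; hoisting start_ms/stop_ms out of the branches lets one LOG line cover both the single- and multi-threaded paths. Below is a minimal standalone sketch of the same pattern, assuming an MKL-enabled build; run_inference, num_threads, and the printf are hypothetical stand-ins for the test's helpers and flags, and omp_get_wtime() stands in for get_current_ms().

// Minimal sketch, not the PaddlePaddle test itself: pin each math library
// to one thread, run N worker threads, and time the whole batch once.
// Assumes an MKL-enabled toolchain (e.g. g++ -fopenmp ... -lmkl_rt).
#include <mkl_service.h>
#include <omp.h>

#include <cstdio>
#include <thread>
#include <vector>

static void run_inference(int worker_id, int num_seqs) {
  // Stand-in for the per-thread inference job over this worker's share
  // of the dataset (bcast_datasets + RunPreparedContext in the test).
  (void)worker_id;
  (void)num_seqs;
}

int main() {
  // The same three calls as the commit: keep every OpenMP/BLAS region
  // single-threaded so N workers occupy N cores, not N * omp_threads.
  omp_set_dynamic(0);      // forbid the runtime from shrinking thread teams
  omp_set_num_threads(1);  // one OpenMP thread per parallel region
  mkl_set_num_threads(1);  // one MKL thread per BLAS call

  const int num_threads = 4;    // hypothetical stand-in for FLAGS_num_threads
  const int seqs_per_job = 100; // hypothetical per-worker workload

  std::vector<std::thread> workers;
  double start_ms = omp_get_wtime() * 1000.0;
  for (int i = 0; i < num_threads; ++i) {
    workers.emplace_back(run_inference, i, seqs_per_job);
  }
  for (auto& t : workers) t.join();
  double stop_ms = omp_get_wtime() * 1000.0;

  // One timing report for any thread count, as in the commit's LOG(INFO).
  std::printf("Total inference time with %d threads : %f sec\n", num_threads,
              (stop_ms - start_ms) / 1000.0);
  return 0;
}

omp_set_dynamic(0) matters here: with dynamic adjustment enabled, the OpenMP runtime may silently grant fewer (or vary) threads than omp_set_num_threads requests, which would make per-thread timings non-reproducible.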