From 6181accaac1556d2ae379d512ac2d146400d2c06 Mon Sep 17 00:00:00 2001 From: wangliu Date: Fri, 13 Jul 2018 12:02:08 +0800 Subject: [PATCH] add interface to dynamic set omp thread num --- src/io/executor.cpp | 11 ----------- src/io/executor.h | 2 -- src/io/paddle_mobile.cpp | 10 ++++++++++ src/io/paddle_mobile.h | 4 ++++ test/net/test_googlenet.cpp | 31 +++++++++++++------------------ 5 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/io/executor.cpp b/src/io/executor.cpp index 510fc8d7db..480f48290c 100644 --- a/src/io/executor.cpp +++ b/src/io/executor.cpp @@ -26,9 +26,6 @@ limitations under the License. */ #include "framework/program/var_desc.h" #include "framework/scope.h" #include "framework/tensor.h" -#ifdef _OPENMP -#include -#endif // _OPENMP #ifdef PADDLE_EXECUTOR_MULTITHREAD #include #include @@ -407,14 +404,6 @@ std::vector::Ptype> Executor::Predict( return result_vector; } -template -void Executor::SetThreadNum(int num) { -#ifdef _OPENMP - // omp_set_dynamic(0); - omp_set_num_threads(num); -#endif -} - template class Executor; template class Executor; template class Executor; diff --git a/src/io/executor.h b/src/io/executor.h index 28b0d65181..f8f2a8ad56 100644 --- a/src/io/executor.h +++ b/src/io/executor.h @@ -58,8 +58,6 @@ class Executor { std::vector Predict(const std::vector &input, const std::vector &dims); - void SetThreadNum(int num); - protected: Executor() = default; void InitMemory(); diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 3d5735f8da..cabdd799a0 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -16,6 +16,14 @@ limitations under the License. */ namespace paddle_mobile { +template +void PaddleMobile::SetThreadNum(int num) { +#ifdef _OPENMP + // omp_set_dynamic(0); + omp_set_num_threads(num); +#endif +}; + template bool PaddleMobile::Load(const std::string &dirname, bool optimize, int batch_size) { @@ -81,7 +89,9 @@ PaddleMobile::~PaddleMobile() { } template class PaddleMobile; + template class PaddleMobile; + template class PaddleMobile; } // namespace paddle_mobile diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 3ce39e0ae1..74c1147156 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -17,6 +17,9 @@ limitations under the License. */ #include #include #include +#ifdef _OPENMP +#include +#endif // _OPENMP #include "common/types.h" #include "framework/tensor.h" @@ -44,6 +47,7 @@ class PaddleMobile { * */ bool Load(const std::string &model_path, const std::string ¶_path, bool optimize = false, int batch_size = 1); + void SetThreadNum(int num); /* * @b to predict diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index 1851f2668d..2ab2473639 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -17,26 +17,21 @@ limitations under the License. */ #include "../test_include.h" int main() { - paddle_mobile::Loader loader; + paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile.SetThreadNum(4); bool optimize = true; auto time1 = time(); - // auto program = loader.Load(g_googlenet, optimize); - auto program = loader.Load(g_googlenet_combine + "/model", - g_googlenet_combine + "/params", optimize); - auto time2 = time(); - DLOG << "load cost :" << time_diff(time1, time2) << "ms\n"; - paddle_mobile::Executor executor(program, 1, optimize); - executor.SetThreadNum(4); - std::vector input; - std::vector dims{1, 3, 224, 224}; - GetInput(g_test_image_1x3x224x224, &input, dims); - auto time3 = time(); - int count = 1; - for (int i = 0; i < count; ++i) { - executor.Predict(input, dims); - } + if (paddle_mobile.Load(g_googlenet, optimize)) { + auto time2 = time(); + DLOG << "load cost :" << time_diff(time1, time1) << "ms"; + std::vector input; + std::vector dims{1, 3, 224, 224}; + GetInput(g_test_image_1x3x224x224, &input, dims); + auto time3 = time(); + auto vec_result = paddle_mobile.Predict(input, dims); + auto time4 = time(); - auto time4 = time(); - DLOG << "predict cost :" << time_diff(time3, time4) / count << "ms\n"; + DLOG << "predict cost :" << time_diff(time3, time4) << "ms"; + } return 0; } -- GitLab