From e3a96300bbf99ec673943bd65994d32084b2d628 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 27 Jun 2018 15:30:01 +0800 Subject: [PATCH] move SetNumThreads to platform --- paddle/fluid/framework/CMakeLists.txt | 3 +- paddle/fluid/framework/init.cc | 4 +- paddle/fluid/inference/io.cc | 4 +- .../tests/book/test_inference_nlp.cc | 4 +- paddle/fluid/operators/math/blas.h | 12 ------ paddle/fluid/platform/CMakeLists.txt | 3 ++ paddle/fluid/platform/cpu_helper.cc | 42 +++++++++++++++++++ paddle/fluid/platform/cpu_helper.h | 26 ++++++++++++ paddle/fluid/platform/cpu_helper_test.cc | 22 ++++++++++ 9 files changed, 101 insertions(+), 19 deletions(-) create mode 100644 paddle/fluid/platform/cpu_helper.cc create mode 100644 paddle/fluid/platform/cpu_helper.h create mode 100644 paddle/fluid/platform/cpu_helper_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6286dda4a5..63f5c2a7f3 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -101,7 +101,8 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) -cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operator) +cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece + operator cpu_helper) cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) diff --git a/paddle/fluid/framework/init.cc b/paddle/fluid/framework/init.cc index a1094976f6..bb34757c1e 100644 --- a/paddle/fluid/framework/init.cc +++ b/paddle/fluid/framework/init.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/init.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/piece.h" @@ -115,7 +115,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { places.emplace_back(platform::CPUPlace()); platform::DeviceContextPool::Init(places); #ifndef PADDLE_WITH_MKLDNN - operators::math::SetNumThreads(1); + platform::SetNumThreads(1); #endif } diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 6b03ac7119..181868977d 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/pybind/pybind.h" DEFINE_string(devices, "", "The devices to be used which is joined by comma."); @@ -33,7 +33,7 @@ namespace inference { void Init(const std::vector argv) { framework::InitGflags(argv); - operators::math::SetNumThreads(FLAGS_math_num_threads); + platform::SetNumThreads(FLAGS_math_num_threads); // init devices std::vector devices; std::string token; diff --git a/paddle/fluid/inference/tests/book/test_inference_nlp.cc b/paddle/fluid/inference/tests/book/test_inference_nlp.cc index 03b0b69463..5cc1db12bb 100644 --- a/paddle/fluid/inference/tests/book/test_inference_nlp.cc +++ b/paddle/fluid/inference/tests/book/test_inference_nlp.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "gtest/gtest.h" #include "paddle/fluid/inference/tests/test_helper.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cpu_helper.h" #ifdef PADDLE_WITH_MKLML #include #endif @@ -164,7 +164,7 @@ TEST(inference, nlp) { // only use 1 thread number per std::thread omp_set_dynamic(0); omp_set_num_threads(1); - paddle::operators::math::SetNumThreads(1); + paddle::platform::SetNumThreads(1); #endif double start_ms = 0, stop_ms = 0; diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index a907d6a71b..3c95968ebe 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -46,18 +46,6 @@ namespace paddle { namespace operators { namespace math { -static void SetNumThreads(int num_threads) { -#ifdef PADDLE_USE_OPENBLAS - int real_num_threads = num_threads > 1 ? num_threads : 1; - openblas_set_num_threads(real_num_threads); -#elif defined(PADDLE_WITH_MKLML) - int real_num_threads = num_threads > 1 ? num_threads : 1; - platform::dynload::MKL_Set_Num_Threads(real_num_threads); -#else - PADDLE_ENFORCE(false, "To be implemented."); -#endif -} - /** * Matrix Descriptor of a memory buffer. * diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index b29035bafd..1a95994cf4 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -28,6 +28,9 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) +cc_library(cpu_helper SRCS cpu_helper.cc DEPS cblas enforce) +cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper) + IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) ELSE() diff --git a/paddle/fluid/platform/cpu_helper.cc b/paddle/fluid/platform/cpu_helper.cc new file mode 100644 index 0000000000..77ecb17011 --- /dev/null +++ b/paddle/fluid/platform/cpu_helper.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/platform/enforce.h" + +#ifdef PADDLE_WITH_MKLML +#include "paddle/fluid/platform/dynload/mklml.h" +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#endif + +namespace paddle { +namespace platform { + +void SetNumThreads(int num_threads) { +#ifdef PADDLE_USE_OPENBLAS + int real_num_threads = num_threads > 1 ? num_threads : 1; + openblas_set_num_threads(real_num_threads); +#elif defined(PADDLE_WITH_MKLML) + int real_num_threads = num_threads > 1 ? num_threads : 1; + platform::dynload::MKL_Set_Num_Threads(real_num_threads); +#else + PADDLE_ENFORCE(false, "To be implemented."); +#endif +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cpu_helper.h b/paddle/fluid/platform/cpu_helper.h new file mode 100644 index 0000000000..78fc392b63 --- /dev/null +++ b/paddle/fluid/platform/cpu_helper.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +namespace paddle { +namespace platform { + +//! Set the number of threads in use. +void SetNumThreads(int num_threads); + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cpu_helper_test.cc b/paddle/fluid/platform/cpu_helper_test.cc new file mode 100644 index 0000000000..dc1b2b56cd --- /dev/null +++ b/paddle/fluid/platform/cpu_helper_test.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/cpu_helper.h" + +#include "gtest/gtest.h" + +TEST(CpuHelper, SetNumThread) { + paddle::platform::SetNumThreads(1); + paddle::platform::SetNumThreads(4); +} -- GitLab