diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 6b03ac7119b117e442e6af34c719c8a4f736bde9..181868977dd8f2568486ed0c4e1f260a69795896 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/pybind/pybind.h" DEFINE_string(devices, "", "The devices to be used which is joined by comma."); @@ -33,7 +33,7 @@ namespace inference { void Init(const std::vector argv) { framework::InitGflags(argv); - operators::math::SetNumThreads(FLAGS_math_num_threads); + platform::SetNumThreads(FLAGS_math_num_threads); // init devices std::vector devices; std::string token; diff --git a/paddle/fluid/inference/tests/book/test_inference_nlp.cc b/paddle/fluid/inference/tests/book/test_inference_nlp.cc index 03b0b6946339772ac535b3471d50fbd74554239d..5cc1db12bb71e428d493e7c6f718b1c6ed431858 100644 --- a/paddle/fluid/inference/tests/book/test_inference_nlp.cc +++ b/paddle/fluid/inference/tests/book/test_inference_nlp.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "gtest/gtest.h" #include "paddle/fluid/inference/tests/test_helper.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cpu_helper.h" #ifdef PADDLE_WITH_MKLML #include #endif @@ -164,7 +164,7 @@ TEST(inference, nlp) { // only use 1 thread number per std::thread omp_set_dynamic(0); omp_set_num_threads(1); - paddle::operators::math::SetNumThreads(1); + paddle::platform::SetNumThreads(1); #endif double start_ms = 0, stop_ms = 0; diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 363cddffa5ccb5f426e9ecb710aa6d35ed8b556a..9f6c1e5c35f02cd4bc729eea78b17fac017aa90e 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -29,18 +29,6 @@ namespace paddle { namespace operators { namespace math { -static void SetNumThreads(int num_threads) { -#ifdef PADDLE_USE_OPENBLAS - int real_num_threads = num_threads > 1 ? num_threads : 1; - openblas_set_num_threads(real_num_threads); -#elif defined(PADDLE_WITH_MKLML) - int real_num_threads = num_threads > 1 ? num_threads : 1; - platform::dynload::MKL_Set_Num_Threads(real_num_threads); -#else - PADDLE_ENFORCE(false, "To be implemented."); -#endif -} - /** * Matrix Descriptor of a memory buffer. * diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index edd5a84aa2993fb45ab801c91b798aee0cd1cec4..20037d0764056c2a093af801c9cc1eb788dd46d6 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -28,6 +28,9 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) +cc_library(cpu_helper SRCS cpu_helper.cc DEPS cblas enforce) +cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper) + IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) ELSE() @@ -43,7 +46,7 @@ ENDIF() # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies cc_library(device_context SRCS device_context.cc init.cc DEPS malloc - place eigen3 stringpiece ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) + place eigen3 stringpiece cpu_helper ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) cc_test(init_test SRCS init_test.cc DEPS device_context) diff --git a/paddle/fluid/platform/cpu_helper.cc b/paddle/fluid/platform/cpu_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..77ecb170111d63f23312d06fa8a8172bc45f2a4e --- /dev/null +++ b/paddle/fluid/platform/cpu_helper.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/platform/enforce.h" + +#ifdef PADDLE_WITH_MKLML +#include "paddle/fluid/platform/dynload/mklml.h" +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#endif + +namespace paddle { +namespace platform { + +void SetNumThreads(int num_threads) { +#ifdef PADDLE_USE_OPENBLAS + int real_num_threads = num_threads > 1 ? num_threads : 1; + openblas_set_num_threads(real_num_threads); +#elif defined(PADDLE_WITH_MKLML) + int real_num_threads = num_threads > 1 ? num_threads : 1; + platform::dynload::MKL_Set_Num_Threads(real_num_threads); +#else + PADDLE_ENFORCE(false, "To be implemented."); +#endif +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cpu_helper.h b/paddle/fluid/platform/cpu_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..78fc392b632ef92d4ae08de2051041fc0bf6778b --- /dev/null +++ b/paddle/fluid/platform/cpu_helper.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +namespace paddle { +namespace platform { + +//! Set the number of threads in use. +void SetNumThreads(int num_threads); + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cpu_helper_test.cc b/paddle/fluid/platform/cpu_helper_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc1b2b56cd98ca6259c46a76231dbc99482970c1 --- /dev/null +++ b/paddle/fluid/platform/cpu_helper_test.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/cpu_helper.h" + +#include "gtest/gtest.h" + +TEST(CpuHelper, SetNumThread) { + paddle::platform::SetNumThreads(1); + paddle::platform::SetNumThreads(4); +} diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 475437785fd314ba035089c289fbcf6af87e258a..0b776528414735e8a7c1e3763e7ccb662bb9f285 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/place.h" @@ -115,7 +115,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { places.emplace_back(platform::CPUPlace()); platform::DeviceContextPool::Init(places); #ifndef PADDLE_WITH_MKLDNN - operators::math::SetNumThreads(1); + platform::SetNumThreads(1); #endif }