提交 e3a96300 编写于 作者: T tensor-tang

move SetNumThreads to platform

上级 b756063c
......@@ -101,7 +101,8 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operator)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece
operator cpu_helper)
cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/piece.h"
......@@ -115,7 +115,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
places.emplace_back(platform::CPUPlace());
platform::DeviceContextPool::Init(places);
#ifndef PADDLE_WITH_MKLDNN
operators::math::SetNumThreads(1);
platform::SetNumThreads(1);
#endif
}
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/pybind/pybind.h"
DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
......@@ -33,7 +33,7 @@ namespace inference {
void Init(const std::vector<std::string> argv) {
framework::InitGflags(argv);
operators::math::SetNumThreads(FLAGS_math_num_threads);
platform::SetNumThreads(FLAGS_math_num_threads);
// init devices
std::vector<int> devices;
std::string token;
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cpu_helper.h"
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
#endif
......@@ -164,7 +164,7 @@ TEST(inference, nlp) {
// only use 1 thread number per std::thread
omp_set_dynamic(0);
omp_set_num_threads(1);
paddle::operators::math::SetNumThreads(1);
paddle::platform::SetNumThreads(1);
#endif
double start_ms = 0, stop_ms = 0;
......
......@@ -46,18 +46,6 @@ namespace paddle {
namespace operators {
namespace math {
static void SetNumThreads(int num_threads) {
#ifdef PADDLE_USE_OPENBLAS
int real_num_threads = num_threads > 1 ? num_threads : 1;
openblas_set_num_threads(real_num_threads);
#elif defined(PADDLE_WITH_MKLML)
int real_num_threads = num_threads > 1 ? num_threads : 1;
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
#else
PADDLE_ENFORCE(false, "To be implemented.");
#endif
}
/**
* Matrix Descriptor of a memory buffer.
*
......
......@@ -28,6 +28,9 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload)
cc_library(cpu_helper SRCS cpu_helper.cc DEPS cblas enforce)
cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper)
IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
ELSE()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#endif
namespace paddle {
namespace platform {
void SetNumThreads(int num_threads) {
#ifdef PADDLE_USE_OPENBLAS
int real_num_threads = num_threads > 1 ? num_threads : 1;
openblas_set_num_threads(real_num_threads);
#elif defined(PADDLE_WITH_MKLML)
int real_num_threads = num_threads > 1 ? num_threads : 1;
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
#else
PADDLE_ENFORCE(false, "To be implemented.");
#endif
}
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h>
namespace paddle {
namespace platform {
//! Set the number of threads in use.
void SetNumThreads(int num_threads);
} // namespace platform
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h"
#include "gtest/gtest.h"
TEST(CpuHelper, SetNumThread) {
paddle::platform::SetNumThreads(1);
paddle::platform::SetNumThreads(4);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册