diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 42072895fc66b2765570cc9a87809ab796057749..19c3f532d5dcb7588793fa21fa179f6b48649103 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -139,7 +139,9 @@ void TestMultiThreadPrediction( } for (int tid = 0; tid < num_threads; ++tid) { threads.emplace_back([&, tid]() { +#ifdef PADDLE_WITH_MKLDNN platform::set_cur_thread_id(static_cast(tid) + 1); +#endif // Each thread should have local inputs and outputs. // The inputs of each thread are all the same. std::vector> inputs_tid = inputs; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 690ba55279df7abc4b02d80e62d493996fd3da64..b0de636de46451c8b05546fdbff142f984c2bb43 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -25,14 +25,6 @@ namespace platform { DeviceContextPool* DeviceContextPool::pool = nullptr; -namespace { -// Current thread's id. -thread_local int cur_thread_id = 0; -} - -void set_cur_thread_id(int tid) { cur_thread_id = tid; } -int get_cur_thread_id(void) { return cur_thread_id; } - platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { @@ -309,6 +301,14 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) p_mutex_.reset(new std::mutex()); } +namespace { +// Current thread's id. +thread_local int cur_thread_id = 0; +} + +void set_cur_thread_id(int tid) { cur_thread_id = tid; } +int get_cur_thread_id(void) { return cur_thread_id; } + void MKLDNNDeviceContext::SetBlob(const std::string& name, std::shared_ptr data) const { BlobMap* pMap = p_blobmap_.get(); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 1527c9f32486936e78af89d0fd7c2a8838834bc7..942e13a724339dc85ed1fc72c11e208ddce36dbb 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -39,12 +39,6 @@ limitations under the License. */ namespace paddle { namespace platform { -using KeyBlob = std::unordered_map>; -using BlobMap = std::unordered_map>; - -void set_cur_thread_id(int); -int get_cur_thread_id(void); - class DeviceContext { public: virtual ~DeviceContext() {} @@ -182,6 +176,12 @@ struct DefaultDeviceContextType { #endif #ifdef PADDLE_WITH_MKLDNN +using KeyBlob = std::unordered_map>; +using BlobMap = std::unordered_map>; + +void set_cur_thread_id(int); +int get_cur_thread_id(void); + class MKLDNNDeviceContext : public CPUDeviceContext { public: explicit MKLDNNDeviceContext(CPUPlace place);