提交 2098b425 编写于 作者: S Sylwester Fraczek

review fixes (Teamcity fails)

test=develop
上级 741cb33b
...@@ -139,7 +139,9 @@ void TestMultiThreadPrediction( ...@@ -139,7 +139,9 @@ void TestMultiThreadPrediction(
} }
for (int tid = 0; tid < num_threads; ++tid) { for (int tid = 0; tid < num_threads; ++tid) {
threads.emplace_back([&, tid]() { threads.emplace_back([&, tid]() {
#ifdef PADDLE_WITH_MKLDNN
platform::set_cur_thread_id(static_cast<int>(tid) + 1); platform::set_cur_thread_id(static_cast<int>(tid) + 1);
#endif
// Each thread should have local inputs and outputs. // Each thread should have local inputs and outputs.
// The inputs of each thread are all the same. // The inputs of each thread are all the same.
std::vector<std::vector<PaddleTensor>> inputs_tid = inputs; std::vector<std::vector<PaddleTensor>> inputs_tid = inputs;
......
...@@ -25,14 +25,6 @@ namespace platform { ...@@ -25,14 +25,6 @@ namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr; DeviceContextPool* DeviceContextPool::pool = nullptr;
namespace {
// Current thread's id.
thread_local int cur_thread_id = 0;
}
void set_cur_thread_id(int tid) { cur_thread_id = tid; }
int get_cur_thread_id(void) { return cur_thread_id; }
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
auto it = device_contexts_.find(place); auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) { if (it == device_contexts_.end()) {
...@@ -309,6 +301,14 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) ...@@ -309,6 +301,14 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
p_mutex_.reset(new std::mutex()); p_mutex_.reset(new std::mutex());
} }
namespace {
// Current thread's id.
thread_local int cur_thread_id = 0;
}
void set_cur_thread_id(int tid) { cur_thread_id = tid; }
int get_cur_thread_id(void) { return cur_thread_id; }
void MKLDNNDeviceContext::SetBlob(const std::string& name, void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const { std::shared_ptr<void> data) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
......
...@@ -39,12 +39,6 @@ limitations under the License. */ ...@@ -39,12 +39,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
void set_cur_thread_id(int);
int get_cur_thread_id(void);
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
...@@ -182,6 +176,12 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> { ...@@ -182,6 +176,12 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
void set_cur_thread_id(int);
int get_cur_thread_id(void);
class MKLDNNDeviceContext : public CPUDeviceContext { class MKLDNNDeviceContext : public CPUDeviceContext {
public: public:
explicit MKLDNNDeviceContext(CPUPlace place); explicit MKLDNNDeviceContext(CPUPlace place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册