提交 a53e8a8d 编写于 作者: B Brian Liu 提交者: Sylwester Fraczek

Update MKLDNN integration framework to support Paddle multi-instances

Make all blob info stored in the global device context thread-based.
Meanwhile, save the thread id in thread-local storage in ParallelDo.
上级 2256fae4
...@@ -25,6 +25,14 @@ namespace platform { ...@@ -25,6 +25,14 @@ namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr; DeviceContextPool* DeviceContextPool::pool = nullptr;
namespace {
// Per-thread logical id; stays 0 until set_cur_thread_id is called
// on that thread.
thread_local int cur_thread_id = 0;
}  // namespace

// Record the logical id of the calling thread (used to key per-thread blobs).
void set_cur_thread_id(int tid) {
  cur_thread_id = tid;
}

// Return the logical id previously stored for the calling thread (0 if unset).
int get_cur_thread_id(void) {
  return cur_thread_id;
}
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
auto it = device_contexts_.find(place); auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) { if (it == device_contexts_.end()) {
...@@ -296,38 +304,65 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; } ...@@ -296,38 +304,65 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Construct the MKLDNN device context: a CPU context plus one MKLDNN CPU
// engine and a blob map shared by all threads (entries keyed by thread id).
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() {
  p_blobmap_.reset(new BlobMap());
  // Guards p_blobmap_ against concurrent SetBlob/GetBlob calls.
  p_mutex_.reset(new std::mutex());
}
void MKLDNNDeviceContext::SetBlob(const std::string& name, void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const { std::shared_ptr<void> data) const {
std::unordered_map<std::string, std::shared_ptr<void>>* p; BlobMap* pMap = p_blobmap_.get();
p = p_blobs_.get(); std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_thread_id();
auto it = p->find(name); std::lock_guard<std::mutex> lock(*p_mutex_.get());
if (it == p->end()) { // Find KeyBlob for current thread
(*p)[name] = data; // create new blob auto map_it = pMap->find(tid);
if (map_it == pMap->end()) {
// 1st time to set blob in current thread
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*pMap)[tid] = pBlob;
} else { } else {
it->second = data; // set data to existing blob pBlob = map_it->second;
} }
// Find Key in found (or newly created) KeyBlob
auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) {
(*pBlob)[name] = data; // create new blob
} else {
key_it->second = data; // set data to existing blob
}
// lock will be automatically released when out of scope
return; return;
} }
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const { const std::string& name) const {
std::unordered_map<std::string, std::shared_ptr<void>>* p; BlobMap* pMap = p_blobmap_.get();
p = p_blobs_.get(); std::shared_ptr<KeyBlob> pBlob = nullptr;
auto it = p->find(name); int tid = platform::get_cur_thread_id();
if (it != p->end()) { std::lock_guard<std::mutex> lock(*p_mutex_.get());
return it->second;
} // Find KeyBlob for current thread firstly
auto map_it = pMap->find(tid);
if (map_it == pMap->end()) return nullptr;
pBlob = map_it->second;
// Find Blob via name
auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) return nullptr;
return nullptr; // lock will be automatically released when out of scope
return key_it->second;
} }
#endif #endif
......
...@@ -39,6 +39,12 @@ limitations under the License. */ ...@@ -39,6 +39,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// Per-thread storage for opaque MKLDNN blobs, keyed by blob name.
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
// Maps a thread id (see get_cur_thread_id) to that thread's KeyBlob.
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;

// Store / retrieve the calling thread's logical id in thread-local storage
// (defaults to 0 until set); used to select the KeyBlob inside the BlobMap.
void set_cur_thread_id(int);
int get_cur_thread_id(void);
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
...@@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
private: private:
mkldnn::engine engine_; mkldnn::engine engine_;
std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>> std::shared_ptr<BlobMap> p_blobmap_;
p_blobs_; std::shared_ptr<std::mutex> p_mutex_;
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册