未验证 提交 3f3112ce 编写于 作者: T Tao Luo 提交者: GitHub

add shape_blob for cache mkldnn primitive (#18454)

test=develop
上级 d234aa02
...@@ -403,42 +403,62 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) ...@@ -403,42 +403,62 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
namespace { namespace {
// Current mkldnn session id. // Current mkldnn session id.
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
} // Current data input shape string.
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
thread_local std::string cur_input_shape_str = "";
} // namespace
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; } size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
void set_cur_input_shape_str(std::string input_shape_str) {
cur_input_shape_str = input_shape_str;
}
std::string get_cur_input_shape_str(void) { return cur_input_shape_str; }
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
void MKLDNNDeviceContext::SetBlob(const std::string& name, void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const { std::shared_ptr<void> data) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr; std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_mkldnn_session_id(); int tid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<std::mutex> lock(*p_mutex_);
// Find KeyBlob for current thread // Find ShapeBlob for current thread
auto map_it = pMap->find(tid); auto map_it = pMap->find(tid);
if (map_it == pMap->end()) { if (map_it == pMap->end()) {
// 1st time to set blob in current thread // 1st time to set blob in current thread
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob()); sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
(*pMap)[tid] = pBlob; (*pMap)[tid] = sBlob;
VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n";
} else { } else {
pBlob = map_it->second; sBlob = map_it->second;
} }
// Find Key in found (or newly created) KeyBlob // Find KeyBlob for current input shape
auto key_it = pBlob->find(name); std::string cur_input_shape_str = platform::get_cur_input_shape_str();
auto key_it = sBlob->find(cur_input_shape_str);
if (key_it == pBlob->end()) { if (key_it == sBlob->end()) {
(*pBlob)[name] = data; // create new blob pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*sBlob)[cur_input_shape_str] = pBlob;
} else { } else {
key_it->second = data; // set data to existing blob pBlob = key_it->second;
} }
// Find Blob via name
auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) {
(*pBlob)[name] = data;
} else {
blob_it->second = data; // set data to existing blob
}
VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n";
// lock will be automatically released when out of scope // lock will be automatically released when out of scope
return; return;
} }
...@@ -446,22 +466,40 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -446,22 +466,40 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const { const std::string& name) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr; std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_mkldnn_session_id(); int tid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<std::mutex> lock(*p_mutex_);
// Find KeyBlob for current thread firstly // Find ShapeBlob for current thread firstly
auto map_it = pMap->find(tid); auto map_it = pMap->find(tid);
if (map_it == pMap->end()) return nullptr; if (map_it == pMap->end()) {
pBlob = map_it->second; VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n";
return nullptr;
}
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
sBlob = map_it->second;
// Find KeyBlob for current input shape secondly
auto sBlob_it = sBlob->find(cur_input_shape_str);
if (sBlob_it == sBlob->end()) {
VLOG(2) << "GetBlob: tid=" << cur_input_shape_str
<< ", miss input_shape_str\n";
return nullptr;
}
pBlob = sBlob_it->second;
// Find Blob via name // Find Blob via name
auto key_it = pBlob->find(name); auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) return nullptr; if (key_it == pBlob->end()) {
VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n";
return nullptr;
}
VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n";
// lock will be automatically released when out of scope // lock will be automatically released when out of scope
return key_it->second; return key_it->second;
} }
......
...@@ -378,8 +378,15 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> { ...@@ -378,8 +378,15 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
// Where:
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>; using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>; using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>;
// default mkldnn session id // default mkldnn session id
constexpr size_t kMKLDNNSessionID_Default = 0; constexpr size_t kMKLDNNSessionID_Default = 0;
...@@ -388,6 +395,8 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1; ...@@ -388,6 +395,8 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
void set_cur_mkldnn_session_id(size_t); void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id(void); size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
std::string get_cur_input_shape_str(void);
class MKLDNNDeviceContext : public CPUDeviceContext { class MKLDNNDeviceContext : public CPUDeviceContext {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册