未验证 提交 3f3112ce 编写于 作者: T Tao Luo 提交者: GitHub

add shape_blob for cache mkldnn primitive (#18454)

test=develop
上级 d234aa02
......@@ -403,42 +403,62 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
// File-local, per-thread state consulted when caching MKLDNN primitives.
// Both variables are thread_local, so every thread sees its own copy.
namespace {
// Current mkldnn session id.
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
// Current data input shape string.
// - For fixed-shape, it's an empty string by default.
// - For dynamic-shape, it's user specific.
thread_local std::string cur_input_shape_str = "";
}  // namespace
// Records `sid` as the MKLDNN session id of the calling thread
// (the backing variable is thread_local).
void set_cur_mkldnn_session_id(size_t sid) {
  cur_mkldnn_session_id = sid;
}
// Returns the MKLDNN session id last stored for the calling thread.
size_t get_cur_mkldnn_session_id() {
  return cur_mkldnn_session_id;
}
// Records `input_shape_str` as the input-shape key of the calling thread
// (the backing variable is thread_local).
void set_cur_input_shape_str(std::string input_shape_str) {
  // The parameter is already a private copy, so exchange buffers instead of
  // copying a second time; the previous value is released when the parameter
  // goes out of scope.
  cur_input_shape_str.swap(input_shape_str);
}
// Returns a copy of the input-shape key last stored for the calling thread.
std::string get_cur_input_shape_str() {
  return cur_input_shape_str;
}
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const {
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_);
// Find KeyBlob for current thread
// Find ShapeBlob for current thread
auto map_it = pMap->find(tid);
if (map_it == pMap->end()) {
// 1st time to set blob in current thread
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*pMap)[tid] = pBlob;
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
(*pMap)[tid] = sBlob;
VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n";
} else {
pBlob = map_it->second;
sBlob = map_it->second;
}
// Find Key in found (or newly created) KeyBlob
auto key_it = pBlob->find(name);
// Find KeyBlob for current input shape
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
auto key_it = sBlob->find(cur_input_shape_str);
if (key_it == pBlob->end()) {
(*pBlob)[name] = data; // create new blob
if (key_it == sBlob->end()) {
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*sBlob)[cur_input_shape_str] = pBlob;
} else {
key_it->second = data; // set data to existing blob
pBlob = key_it->second;
}
// Find Blob via name
auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) {
(*pBlob)[name] = data;
} else {
blob_it->second = data; // set data to existing blob
}
VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n";
// lock will be automatically released when out of scope
return;
}
......@@ -446,22 +466,40 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
// Looks up the blob cached under `name` for the calling thread's current
// MKLDNN session id and current input-shape string.
// The cache has three levels:
//   BlobMap[session id] -> ShapeBlob[input shape] -> KeyBlob[name] -> blob
// Returns nullptr (and logs the miss at VLOG level 2) when any level has no
// matching entry.
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* pMap = p_blobmap_.get();
  std::shared_ptr<ShapeBlob> sBlob = nullptr;
  std::shared_ptr<KeyBlob> pBlob = nullptr;

  // Both keys come from thread-local state, so read them before locking.
  // BlobMap is keyed by int, hence the explicit narrowing of the size_t id.
  const int tid = static_cast<int>(platform::get_cur_mkldnn_session_id());
  const std::string input_shape_str = platform::get_cur_input_shape_str();

  std::lock_guard<std::mutex> lock(*p_mutex_);

  // Find ShapeBlob for current thread firstly
  auto map_it = pMap->find(tid);
  if (map_it == pMap->end()) {
    VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n";
    return nullptr;
  }
  sBlob = map_it->second;

  // Find KeyBlob for current input shape secondly
  auto sBlob_it = sBlob->find(input_shape_str);
  if (sBlob_it == sBlob->end()) {
    // Report the session id under "tid" and the shape under its own label;
    // the shape string used to be printed as if it were the tid.
    VLOG(2) << "GetBlob: tid=" << tid
            << ", miss input_shape_str=" << input_shape_str << "\n";
    return nullptr;
  }
  pBlob = sBlob_it->second;

  // Find Blob via name
  auto key_it = pBlob->find(name);
  if (key_it == pBlob->end()) {
    VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n";
  // lock will be automatically released when out of scope
  return key_it->second;
}
......
......@@ -378,8 +378,15 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif
#ifdef PADDLE_WITH_MKLDNN
// Following three maps are used to cache MKLDNN primitives.
// Their relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
// Where:
// KeyBlob maps a blob name to the type-erased cached object.
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
// ShapeBlob maps an input-shape string to the KeyBlob for that shape.
using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>;
// BlobMap maps a session (thread) id to the ShapeBlob for that session.
using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>;
// Default MKLDNN session id.
constexpr size_t kMKLDNNSessionID_Default{0};
......@@ -388,6 +395,8 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
// Setter / getter for the current MKLDNN session id.
void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id();
// Setter / getter for the current input-shape string.
void set_cur_input_shape_str(std::string input_shape_str);
std::string get_cur_input_shape_str();
class MKLDNNDeviceContext : public CPUDeviceContext {
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册