提交 8f5fffca 编写于 作者: L Leo Zhao 提交者: Tao Luo

rename mkldnn set/get_cur_thread_id() to set/get_cur_mkldnn_session_id() (#18453)

* rename mkldnn set/get_cur_thread_id() to set/get_cur_mkldnn_session_id()

test=develop

* update session id definition and adjust logic for default behavior

test=develop

* reset logic in mkldnn reuse as most of cases work in default.

test=develop
上级 41ab76e5
...@@ -81,7 +81,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -81,7 +81,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, platform::MKLDNNHandler::AppendKey(&key,
std::to_string(multi_input[0]->format())); std::to_string(multi_input[0]->format()));
if (platform::get_cur_thread_id() != -1) { if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
auto tid = std::this_thread::get_id(); auto tid = std::this_thread::get_id();
std::stringstream ss; std::stringstream ss;
ss << tid; ss << tid;
......
...@@ -48,7 +48,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -48,7 +48,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix); platform::MKLDNNHandler::AppendKey(&key, suffix);
if (platform::get_cur_thread_id() != -1) { if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
auto tid = std::this_thread::get_id(); auto tid = std::this_thread::get_id();
std::stringstream ss; std::stringstream ss;
ss << tid; ss << tid;
......
...@@ -401,12 +401,12 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) ...@@ -401,12 +401,12 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
} }
namespace { namespace {
// Current thread's id. // Current mkldnn session id.
thread_local int cur_thread_id = 0; thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
} }
void set_cur_thread_id(int tid) { cur_thread_id = tid; } void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
int get_cur_thread_id(void) { return cur_thread_id; } size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
...@@ -415,7 +415,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -415,7 +415,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<KeyBlob> pBlob = nullptr; std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_thread_id(); int tid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<std::mutex> lock(*p_mutex_);
...@@ -448,7 +448,7 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( ...@@ -448,7 +448,7 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<KeyBlob> pBlob = nullptr; std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_thread_id(); int tid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<std::mutex> lock(*p_mutex_);
......
...@@ -381,8 +381,13 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> { ...@@ -381,8 +381,13 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>; using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>; using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
void set_cur_thread_id(int); // default mkldnn session id
int get_cur_thread_id(void); constexpr size_t kMKLDNNSessionID_Default = 0;
// mkldnn session id for cache clearing mode
constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id(void);
class MKLDNNDeviceContext : public CPUDeviceContext { class MKLDNNDeviceContext : public CPUDeviceContext {
public: public:
......
...@@ -39,7 +39,8 @@ class MKLDNNHandler { ...@@ -39,7 +39,8 @@ class MKLDNNHandler {
std::stringstream ss; std::stringstream ss;
ss << tid; ss << tid;
key_ = key_common_ + "-t:" + ss.str(); key_ = key_common_ + "-t:" + ss.str();
if (platform::get_cur_thread_id() == -1) { if (platform::get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) {
key_ = key_common_; key_ = key_common_;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册