提交 8f5fffca 编写于 作者: L Leo Zhao 提交者: Tao Luo

rename mkldnn set/get_cur_thread_id() to set/get_cur_mkldnn_session_id() (#18453)

* rename mkldnn set/get_cur_thread_id() to set/get_cur_mkldnn_session_id()

test=develop

* update session id definition and adjust logic for default behavior

test=develop

* reset logic in mkldnn reuse as most of cases work in default.

test=develop
上级 41ab76e5
......@@ -81,7 +81,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key,
std::to_string(multi_input[0]->format()));
if (platform::get_cur_thread_id() != -1) {
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
......
......@@ -48,7 +48,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix);
if (platform::get_cur_thread_id() != -1) {
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
......
......@@ -401,12 +401,12 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
}
namespace {
// Current thread's id.
thread_local int cur_thread_id = 0;
// Current mkldnn session id.
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}
void set_cur_thread_id(int tid) { cur_thread_id = tid; }
int get_cur_thread_id(void) { return cur_thread_id; }
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
......@@ -415,7 +415,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_thread_id();
int tid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_);
......@@ -448,7 +448,7 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_thread_id();
int tid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_);
......
......@@ -381,8 +381,13 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
void set_cur_thread_id(int);
int get_cur_thread_id(void);
// default mkldnn session id
constexpr size_t kMKLDNNSessionID_Default = 0;
// mkldnn session id for cache clearing mode
constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id(void);
class MKLDNNDeviceContext : public CPUDeviceContext {
public:
......
......@@ -39,7 +39,8 @@ class MKLDNNHandler {
std::stringstream ss;
ss << tid;
key_ = key_common_ + "-t:" + ss.str();
if (platform::get_cur_thread_id() == -1) {
if (platform::get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) {
key_ = key_common_;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册