diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 6e44ad04416b25237e973411024f838c8666e56f..59a76ce103c0e30b1a927b14ae9b01bdb7a275ce 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -124,9 +124,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to " "non-MKLDNN"); - innerTransDataLayoutFromMKLDNN(in_layout, - paddle::platform::get_cur_paddle_data_layout(), - in, out, place); + innerTransDataLayoutFromMKLDNN( + in_layout, + paddle::platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout(), + in, out, place); } void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc index 245b9728b6cbc821b01449d692ab3bed933749ff..76c53e82315773dfc2d9f1c073e055e35b1fee00 100644 --- a/paddle/fluid/framework/data_transform.cc +++ b/paddle/fluid/framework/data_transform.cc @@ -59,7 +59,8 @@ void TransformData(const OpKernelType &expected_kernel_type, // For NHWC data we need reshape of tensors as MKL-DNN // is expecting NHWC dims description order platform::MatchShapeToLayout(&out, lin, lout); - paddle::platform::set_cur_paddle_data_layout(lin); + paddle::platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( + lin); out.set_layout(DataLayout::kMKLDNN); out.set_format(out_format); } else { diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 396c2d664f7b67e7b1f0ab79c8d28a83a735d223..b18f426883402c9f2ec17bb58ad985b41302709a 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -89,7 +89,8 @@ Executor::~Executor() { platform::MKLDNNDeviceContext* dev_ctx = (platform::MKLDNNDeviceContext*)pool.Get(place_); dev_ctx->ResetBlobMap(); - 
platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW); + platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( + paddle::framework::DataLayout::kNCHW); } #endif } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5f9be4f2cb85b6b6839a2c4f7f15eb9464f58cad..055fc8f707d3c9a9f1f1dc28c3d7b6478a0ee1ef 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1155,8 +1155,8 @@ Scope* OperatorWithKernel::PrepareData( if ((tensor_in->layout() == DataLayout::kMKLDNN) && (var->IsType() == true) && (expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) && - (paddle::platform::get_cur_paddle_data_layout() == - DataLayout::kNHWC)) { + (paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == DataLayout::kNHWC)) { // Mixed execution : MKL-DNN and GPU is not supported! if (!new_scope) { new_scope = &scope.NewScope(); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index bbe5596e19e4e0ef02687f6a12b5505d9b7a61d1..23c30986c2685cbc273034d46035c45559d04035 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -244,13 +244,14 @@ bool AnalysisPredictor::PrepareExecutor() { void AnalysisPredictor::MkldnnPreSet(const std::vector &inputs) { #ifdef PADDLE_WITH_MKLDNN VLOG(2) << "AnalysisPredictor::Run get_cur_mkldnn_session_id=" - << platform::get_cur_mkldnn_session_id(); + << platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id(); // In cache clearing mode. 
if (config_.mkldnn_cache_capacity_ > 0) { VLOG(2) << "In mkldnn cache clear mode."; - platform::set_cur_mkldnn_session_id( - platform::kMKLDNNSessionID_CacheClearing); - platform::set_cur_input_shape_cache_capacity( + platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id( + platform::MKLDNNDeviceContextThreadLocals:: + kMKLDNNSessionID_CacheClearing); + platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity( config_.mkldnn_cache_capacity_); // Set current_input_shape for caching dynamic shape. std::stringstream ss; @@ -260,7 +261,7 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector &inputs) { } } VLOG(2) << "Set input shape=" << ss.str(); - platform::set_cur_input_shape_str(ss.str()); + platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(ss.str()); } #endif } @@ -277,10 +278,10 @@ void AnalysisPredictor::MkldnnPostReset() { CHECK_LE(shape_blob_size, static_cast(config_.mkldnn_cache_capacity_)); } - paddle::platform::set_cur_mkldnn_session_id( - platform::kMKLDNNSessionID_Default); - platform::set_cur_input_shape_cache_capacity(0); - platform::set_cur_input_shape_str(""); + paddle::platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id( + platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default); + platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(0); + platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(""); } #endif } diff --git a/paddle/fluid/operators/controlflow/fetch_op.cc b/paddle/fluid/operators/controlflow/fetch_op.cc index 5d6434f45f0b9e380ef7e18331368bb59c5a1460..d86b6b48422d94604724303de72f401bfba2e23e 100644 --- a/paddle/fluid/operators/controlflow/fetch_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_op.cc @@ -34,10 +34,10 @@ static void DataCopy(const framework::LoDTensor &src_item, // Convert to desired Paddle layout, apart from grads of filter // as params are not a subject to paddle's data_format framework::innerTransDataLayoutFromMKLDNN( - 
src_item.layout(), - fetch_var_name == framework::GradVarName("Filter") - ? framework::DataLayout::kNCHW - : paddle::platform::get_cur_paddle_data_layout(), + src_item.layout(), fetch_var_name == framework::GradVarName("Filter") + ? framework::DataLayout::kNCHW + : paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout(), src_item, &out, platform::CPUPlace()); TensorCopySync(out, platform::CPUPlace(), dst_item); } else { diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 68d6d04b0ebb428096da818746eaa3fe1a203d99..c6f782046c95271aa4c63106ca3bd00617eaf43c 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -446,8 +446,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // of conv int8 mkl-dnn. Once conv fp32 and conv int8 // are merged/unified, this will disappear std::string key_tid = ""; - if (platform::get_cur_mkldnn_session_id() == - platform::kMKLDNNSessionID_Default) { + if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == + platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { key_tid = "-t:" + platform::ThreadIDasStr(); } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 219a23846856cc3348ed44cf52a634c2a068017d..eb44080efebaf73900959ae78277cabffccca67d 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) p_mutex_.reset(new std::mutex()); } -namespace { -// Current mkldnn session id. -thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; -// Current data input shape string. -// - For fixed-shape, it's a null string in default. -// - For dynamic-shape, it's user specific. 
-thread_local std::string cur_input_shape_str = ""; -// the cache capacity of different input shapes for MKLDNN. -// Default 1 means fixed input shape, not dynamic shape. -thread_local int cur_input_shape_cache_capacity = 1; -// Recently registered data_format. This is needed to -// know for converting MKL-DNN Tensor to non MKL-DNN -thread_local paddle::framework::DataLayout cur_paddle_data_layout = - paddle::framework::DataLayout::kNCHW; -} // namespace - -void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } -size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; } -void set_cur_input_shape_str(std::string input_shape_str) { +MKLDNNDeviceContextThreadLocals::Body::Body() { + cur_mkldnn_session_id = kMKLDNNSessionID_Default; + cur_input_shape_str = ""; + cur_input_shape_cache_capacity = 1; + cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW; +} + +void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id( + size_t sid) { + cur_mkldnn_session_id = sid; +} +size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) { + return cur_mkldnn_session_id; +} + +void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str( + std::string input_shape_str) { cur_input_shape_str = input_shape_str; } -void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) { +void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity( + int input_shape_cache_capacity) { cur_input_shape_cache_capacity = input_shape_cache_capacity; } -void set_cur_paddle_data_layout(framework::DataLayout dl) { +void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout( + framework::DataLayout dl) { cur_paddle_data_layout = dl; } -framework::DataLayout get_cur_paddle_data_layout(void) { +framework::DataLayout +MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) { return cur_paddle_data_layout; } @@ -414,32 +415,32 @@ void MKLDNNDeviceContext::ResetBlobMap() 
const { } size_t MKLDNNDeviceContext::GetShapeBlobSize() const { - std::lock_guard lock(*p_mutex_); + std::lock_guard lock(*p_mutex_); BlobMap* pMap = p_blobmap_.get(); - auto map_it = pMap->find(cur_mkldnn_session_id); + auto map_it = pMap->find(tls().cur_mkldnn_session_id); if (map_it == pMap->end()) { LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : " - << cur_mkldnn_session_id; + << tls().cur_mkldnn_session_id; } return map_it->second->size(); } void MKLDNNDeviceContext::SetBlob(const std::string& name, - std::shared_ptr data) const { + BlobPtr_t data) const { BlobMap* pMap = p_blobmap_.get(); - std::shared_ptr sBlob = nullptr; - std::shared_ptr pBlob = nullptr; + BlobPtr_t sBlob = nullptr; + BlobPtr_t pBlob = nullptr; - int sid = platform::get_cur_mkldnn_session_id(); + int sid = tls().get_cur_mkldnn_session_id(); - std::lock_guard lock(*p_mutex_); + std::lock_guard lock(*p_mutex_); // Find ShapeBlob for current mkldnn session id. auto map_it = pMap->find(sid); if (map_it == pMap->end()) { // 1st time to set blob in current thread - sBlob = std::shared_ptr(new ShapeBlob()); + sBlob = std::make_shared(); (*pMap)[sid] = sBlob; VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n"; } else { @@ -447,21 +448,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, } // Find KeyBlob for current input shape - auto key_it = sBlob->find(cur_input_shape_str); + auto key_it = sBlob->find(tls().cur_input_shape_str); if (key_it == sBlob->end()) { // In cache clearing mode, cur_input_shape_cache_capacity defines // max pblob capacity - if ((static_cast(sid) == kMKLDNNSessionID_CacheClearing) && + if ((static_cast(sid) == + MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) && sBlob->size() && (sBlob->size() >= - static_cast(cur_input_shape_cache_capacity))) { + static_cast(tls().cur_input_shape_cache_capacity))) { VLOG(2) << "sid=" << sid << ", remove all blobs of shape: " << sBlob->begin()->first; 
sBlob->erase(sBlob->begin()->first); } - pBlob = std::shared_ptr(new KeyBlob()); - (*sBlob)[cur_input_shape_str] = pBlob; + pBlob = std::make_shared(); + (*sBlob)[tls().cur_input_shape_str] = pBlob; } else { pBlob = key_it->second; } @@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, return; } -std::shared_ptr MKLDNNDeviceContext::GetBlob( +MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( const std::string& name) const { BlobMap* pMap = p_blobmap_.get(); - std::shared_ptr sBlob = nullptr; - std::shared_ptr pBlob = nullptr; + BlobPtr_t sBlob = nullptr; + BlobPtr_t pBlob = nullptr; - int sid = platform::get_cur_mkldnn_session_id(); + int sid = tls().get_cur_mkldnn_session_id(); - std::lock_guard lock(*p_mutex_); + std::lock_guard lock(*p_mutex_); // Find ShapeBlob for current mkldnn session id firstly auto map_it = pMap->find(sid); @@ -497,9 +499,9 @@ std::shared_ptr MKLDNNDeviceContext::GetBlob( sBlob = map_it->second; // Find KeyBlob for current input shape secondly - auto sBlob_it = sBlob->find(cur_input_shape_str); + auto sBlob_it = sBlob->find(tls().cur_input_shape_str); if (sBlob_it == sBlob->end()) { - VLOG(2) << "GetBlob: sid=" << cur_input_shape_str + VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str << ", miss input_shape_str\n"; return nullptr; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 76fa9ee09b8db40d92b5c47ed4a4cd81eb170dc5..9393ea3e332cb9cc9723a83693725c4c7ed4707c 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -421,30 +421,66 @@ struct DefaultDeviceContextType { #endif #ifdef PADDLE_WITH_MKLDNN -// Following three maps are used to cache MKLDNN primitives. 
-// There relations are: -// - BlobMap = Map -// - ShapeBlob = Map -// - KeyBlob = Map -// Where: -using KeyBlob = std::unordered_map>; -using ShapeBlob = std::unordered_map>; -using BlobMap = std::unordered_map>; - -// default mkldnn session id -constexpr size_t kMKLDNNSessionID_Default = 0; -// mkldnn session id for cache clearing mode -constexpr size_t kMKLDNNSessionID_CacheClearing = -1; - -void set_cur_mkldnn_session_id(size_t); -size_t get_cur_mkldnn_session_id(void); -void set_cur_input_shape_str(std::string input_shape_str); -void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity); -void set_cur_paddle_data_layout(framework::DataLayout); -framework::DataLayout get_cur_paddle_data_layout(void); + +class MKLDNNDeviceContextThreadLocals { + // default mkldnn session id + + typedef MKLDNNDeviceContextThreadLocals self; + struct Body { + size_t cur_mkldnn_session_id; + // Current data input shape string. + // - For fixed-shape, it's a null string in default. + // - For dynamic-shape, it's user specific. + std::string cur_input_shape_str; + // the cache capacity of different input shapes for MKLDNN. + // Default 1 means fixed input shape, not dynamic shape. + int cur_input_shape_cache_capacity; + // Recently registered data_format. 
This is needed to + // know for converting MKL-DNN Tensor to non MKL-DNN + paddle::framework::DataLayout cur_paddle_data_layout; + + Body(); + void set_cur_mkldnn_session_id(size_t sid); + size_t get_cur_mkldnn_session_id(void); + void set_cur_input_shape_str(std::string input_shape_str); + void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity); + void set_cur_paddle_data_layout(framework::DataLayout dl); + framework::DataLayout get_cur_paddle_data_layout(void); + }; + MKLDNNDeviceContextThreadLocals() = default; + MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) = + delete; + + public: + // default mkldnn session id + static constexpr size_t kMKLDNNSessionID_Default = 0; + // mkldnn session id for cache clearing mode + static constexpr size_t kMKLDNNSessionID_CacheClearing = -1; + static Body& fetch() { + thread_local Body b; + return b; + } +}; class MKLDNNDeviceContext : public CPUDeviceContext { public: + template + using BlobPtr_t = std::shared_ptr; + template + using umap_value_smart_t = std::unordered_map>; + template + using umap_key_string_t = umap_value_smart_t; + + // Following three maps are used to cache MKLDNN primitives. + // Their relations are: + // - BlobMap = Map<cur_mkldnn_session_id, ShapeBlob> + // - ShapeBlob = Map<cur_input_shape_str, KeyBlob> + // - KeyBlob = Map<blob_name, blob> + + using KeyBlob = umap_key_string_t; + using ShapeBlob = umap_key_string_t; + using BlobMap = umap_value_smart_t; + explicit MKLDNNDeviceContext(CPUPlace place); /* \brief Get the active engine */ @@ -462,6 +498,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { // Find a saved blob. 
Return nullptr if not found std::shared_ptr GetBlob(const std::string& name) const; + static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) { + return MKLDNNDeviceContextThreadLocals::fetch(); + } + private: mkldnn::engine engine_; std::shared_ptr p_blobmap_; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 9c9e7924b396688395c9c4c837faf6f86cc0e165..2fd7e614cc7b811ca3f49ae8386c29d729eaa697 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -42,8 +42,8 @@ class MKLDNNHandlerT { key_common_(base_key), fwd_pd_(nullptr), bwd_pd_(nullptr) { - if (platform::get_cur_mkldnn_session_id() != - platform::kMKLDNNSessionID_Default) { + if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != + platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { key_ = key_common_; } else { key_ = key_common_ + "-t:" + ThreadIDasStr(); @@ -177,8 +177,8 @@ class MKLDNNHandler { MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, const std::string& base_key) : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) { - if (platform::get_cur_mkldnn_session_id() != - platform::kMKLDNNSessionID_Default) { + if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != + platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { key_ = key_common_; } else { key_ = key_common_ + "-t:" + ThreadIDasStr();