From bf748f245eb74ffc86e44853fa9ebad7c858b015 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Wed, 13 Oct 2021 08:40:20 +0200 Subject: [PATCH] Implemented LRU based cache clearing (#36290) - Lint - Merge with develop - lint --- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 49 ++++---- .../mkldnn/conv_transpose_mkldnn_op.cc | 33 +++--- .../operators/mkldnn/quantize_mkldnn_op.cc | 105 ++++++------------ paddle/fluid/platform/device_context.cc | 63 +++++++---- paddle/fluid/platform/device_context.h | 15 ++- paddle/fluid/platform/mkldnn_reuse.h | 17 +-- 6 files changed, 136 insertions(+), 146 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index cce835e6bc0..84c989f64e4 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -78,7 +78,8 @@ class ConvMKLDNNHandlerT mkldnn::convolution_backward_weights>( dev_ctx, mkldnn_engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), - unique_name)) { + unique_name)), + is_test_(ctx.Attr("is_test")) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( input->layout(), framework::DataLayout::kMKLDNN, @@ -159,7 +160,6 @@ class ConvMKLDNNHandlerT framework::slice_ddim(filter_dims, 2, filter_dims.size()); const auto ksize = framework::vectorize(filter_data_dims); - const bool is_test = ctx.Attr("is_test"); auto strides_temp = ctx.Attr>("strides"); std::vector strides(begin(strides_temp), end(strides_temp)); @@ -214,9 +214,8 @@ class ConvMKLDNNHandlerT const auto dst_md = platform::MKLDNNMemDesc( dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference - : mkldnn::prop_kind::forward_training; - + const auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference + : mkldnn::prop_kind::forward_training; float sum_scale = 1.0f; std::vector output_shift_scale; if (platform::is_int8()) @@ -261,7 +260,8 @@ class ConvMKLDNNHandlerT mkldnn::convolution_backward_weights>( dev_ctx, dev_ctx.GetEngine(), cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(in->dims()), - unique_name)) { + unique_name)), + is_test_(false) { if (!this->isBwdCached()) { PADDLE_ENFORCE_EQ( in->layout(), framework::DataLayout::kMKLDNN, @@ -291,7 +291,7 @@ class ConvMKLDNNHandlerT "Wrong format set for output_grad tensor")); PADDLE_ENFORCE_EQ( - ctx.Attr("is_test"), false, + is_test_, false, platform::errors::InvalidArgument( "is_test attribute should be set to False in training phase.")); @@ -557,13 +557,14 @@ class ConvMKLDNNHandlerT framework::vectorize(in_mem->dims()), platform::MKLDNNGetDataType(), in_mem->format()); return this->AcquireMemoryWithReorder( - user_mem_md, mem_md, platform::to_void_cast(in_mem_data), key_mem); + user_mem_md, mem_md, platform::to_void_cast(in_mem_data), key_mem, + is_test_); } else { const std::string target_key_suffix{key_mem_target}; const auto target_mem_p = this->AcquireMemory(target_key_suffix); user_mem_p->set_data_handle(platform::to_void_cast(in_mem_data)); if (user_mem_p != target_mem_p) { - this->AcquireReorder(user_mem_p, target_mem_p, key_mem); + this->AcquireReorder(user_mem_p, target_mem_p); } return target_mem_p; } @@ -571,12 +572,11 @@ class ConvMKLDNNHandlerT std::shared_ptr AcquireWeightsMemoryWithReorder( const framework::Tensor* filter, const int groups, const bool is_conv3d, - const bool is_test, const std::vector& scale_data = {1.0f}, - int mask = 0) { + const std::vector& scale_data = {1.0f}, int mask = 0) { // This is workaround to make execution faster, delete // if statement after including md inside Tensor auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); - if (is_test && weights_mem_p) { + if (is_test_ && weights_mem_p) { return weights_mem_p; } else { const K* filter_data = filter->data(); @@ -589,16 +589,16 @@ class ConvMKLDNNHandlerT return this->AcquireMemoryWithReorder( user_src_md, this->fwd_pd_->weights_desc(), - platform::to_void_cast(filter_data), "@weights_mem_p", is_test, {}, - scale_data, mask); + platform::to_void_cast(filter_data), "@weights_mem_p", is_test_, + {}, scale_data, mask); } } std::shared_ptr AcquireBiasMemoryWithReorder( - const framework::Tensor* bias, const bool is_test, + const framework::Tensor* bias, const std::vector& scale_data = {1.0f}, int mask = 0) { auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); - if (is_test && bias_mem_p) { + if (is_test_ && bias_mem_p) { return bias_mem_p; } else { const K* bias_data = bias->data(); @@ -608,7 +608,7 @@ class ConvMKLDNNHandlerT return this->AcquireMemoryWithReorder( user_bias_md, this->fwd_pd_->bias_desc(), - platform::to_void_cast(bias_data), "@bias_mem_p", is_test, {}, + platform::to_void_cast(bias_data), "@bias_mem_p", is_test_, {}, scale_data, mask); } } @@ -641,7 +641,7 @@ class ConvMKLDNNHandlerT platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) { auto residual_memory_p = this->AcquireResidualMemory(residual_param); dst_memory_p = this->template AcquireDstMemory(output); - this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst"); + this->AcquireReorder(residual_memory_p, dst_memory_p); } else { // Changing ShareDataWith to TensorCopy results in performance drop // on ResNet architectures @@ -651,6 +651,9 @@ class ConvMKLDNNHandlerT } return dst_memory_p; } + + private: + const bool is_test_; }; } // anonymous namespace @@ -695,7 +698,6 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); - const bool is_test = ctx.Attr("is_test"); const bool is_conv3d = ctx.Attr>("strides").size() == 3U; const bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); @@ -712,7 +714,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, ctx.Attr("groups"), is_conv3d, is_test); + filter, ctx.Attr("groups"), is_conv3d); std::shared_ptr dst_memory_p; if (fuse_residual_conn) { @@ -731,7 +733,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { {MKLDNN_ARG_DST, *dst_memory_p}}; if (bias) { - auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test); + auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); } @@ -783,11 +785,10 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { ctx.Attr>("Scale_weights"); const bool is_multi_channel = scale_weights_data.size() > 1; const int& groups = ctx.Attr("groups"); - const bool& is_test = ctx.Attr("is_test"); int mask_reorder = is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0; auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, groups, false, is_test, scale_weights_data, mask_reorder); + filter, groups, false, scale_weights_data, mask_reorder); std::shared_ptr dst_memory_p; if (fuse_residual_conn) { @@ -822,7 +823,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { handler.get_int8_bias_scales(ctx); auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( - bias, is_test, scale_bias_data, mask_reorder); + bias, scale_bias_data, mask_reorder); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); } diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index 8d43e9f0dca..4c374d72c04 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -51,10 +51,10 @@ class ConvTransposeMKLDNNHandlerT : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), - unique_name)) { + unique_name)), + is_test_(ctx.Attr("is_test")) { if (!this->isCached()) { - const bool is_test = ctx.Attr("is_test"); - PADDLE_ENFORCE_EQ(is_test, true, + PADDLE_ENFORCE_EQ(is_test_, true, platform::errors::InvalidArgument( "ConvTransposeMKLDNN works only for inference. " "The attribute \'is_test\' value should be set to " @@ -169,8 +169,8 @@ class ConvTransposeMKLDNNHandlerT const mkldnn::primitive_attr conv_trans_attr = CreatePostOps(fuse_activation, fuse_alpha, fuse_beta); - auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference - : mkldnn::prop_kind::forward_training; + auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference + : mkldnn::prop_kind::forward_training; if (bias) { std::vector bias_tz = framework::vectorize(bias->dims()); const auto bias_md = @@ -231,18 +231,18 @@ class ConvTransposeMKLDNNHandlerT const auto target_src_mem_p = this->AcquireMemory(target_key_suffix); user_src_mem_p->set_data_handle(platform::to_void_cast(input_data)); if (user_src_mem_p != target_src_mem_p) { - this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p"); + this->AcquireReorder(user_src_mem_p, target_src_mem_p); } return target_src_mem_p; } } std::shared_ptr AcquireWeightsMemoryWithReorder( - const framework::Tensor* filter, const int& groups, const bool& is_test) { + const framework::Tensor* filter, const int& groups) { // This is workaround to make execution faster, delete // if statement after including md inside Tensor auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); - if (is_test && weights_mem_p) { + if (is_test_ && weights_mem_p) { return weights_mem_p; } else { const K* filter_data = filter->data(); @@ -277,15 +277,15 @@ class ConvTransposeMKLDNNHandlerT return this->template AcquireMemoryWithReorder( user_src_md, this->fwd_pd_->weights_desc(), - platform::to_void_cast(filter_data), "@weights_mem_p", is_test, + platform::to_void_cast(filter_data), "@weights_mem_p", is_test_, iohw2oihw_reorder); } } std::shared_ptr AcquireBiasMemoryWithReorder( - const framework::Tensor* bias, const bool& is_test) { + const framework::Tensor* bias) { auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); - if (is_test && bias_mem_p) { + if (is_test_ && bias_mem_p) { return bias_mem_p; } else { const K* bias_data = bias->data(); @@ -294,9 +294,12 @@ class ConvTransposeMKLDNNHandlerT MKLDNNMemoryFormat::x); return this->AcquireMemoryWithReorder( user_bias_md, this->fwd_pd_->bias_desc(), - platform::to_void_cast(bias_data), "@bias_mem_p", is_test); + platform::to_void_cast(bias_data), "@bias_mem_p", is_test_); } } + + private: + const bool is_test_; }; template @@ -325,8 +328,6 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel { ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); - const bool is_test = ctx.Attr("is_test"); - const auto* input = ctx.Input("Input"); const auto* filter = ctx.Input("Filter"); const auto* bias = @@ -340,7 +341,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel { output, unique_name); auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, ctx.Attr("groups"), is_test); + filter, ctx.Attr("groups")); std::shared_ptr dst_memory_p = handler.template AcquireDstMemory(output); @@ -352,7 +353,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel { {MKLDNN_ARG_DST, *dst_memory_p}}; if (bias) { - auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test); + auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); } auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc index 819c0d15505..815af4eaaf1 100644 --- a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc @@ -64,81 +64,46 @@ class QuantOpKernel : public framework::OpKernel { bool is_negative_input = ctx.Attr("is_negative_input"); bool bfloat16 = ctx.Attr("bfloat16"); - std::string key = - platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift, - is_negative_input, ctx.OutputName("Output")); - key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); - - const std::string key_prim = key + "@r"; - const std::string key_src_mem = key + "@s"; - const std::string key_dst_mem = key + "@d"; - + // TODO(jczaja): Refactor with Acquire API std::shared_ptr src_memory; std::shared_ptr dst_memory; std::shared_ptr reorder_p; - reorder_p = std::static_pointer_cast(dev_ctx.GetBlob(key_prim)); - - if (reorder_p == nullptr) { - std::string out_layout = ctx.Attr("output_format"); - MKLDNNMemoryFormat out_format = - platform::data_format_to_memory_format(out_layout); - mkldnn::primitive_attr attri; - int mask = 0; - attri.set_output_scales(mask, {scale_data}); - - if (with_shift) { - mkldnn::post_ops post_operations; - post_operations.append_sum(); - attri.set_post_ops(post_operations); - uint8_t* output_data = output->mutable_data(ctx.GetPlace()); - // memset casts scale_shift to unsigned char (uint8_t) internally - std::memset(output_data, scale_shift, output->numel()); - } - - auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32, - input->format()); - src_memory = std::make_shared( - src_md, engine, to_void_cast(input_data)); - - std::shared_ptr dst_md; - if (bfloat16) { - platform::SetDstMemoryQuantized( - ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); - } else if (is_negative_input && !with_shift) { - platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, - dst_md, dst_memory, out_format); - } else { - platform::SetDstMemoryQuantized( - ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); - } - auto reorder_pd = std::shared_ptr( - new reorder::primitive_desc(*src_memory, *dst_memory, attri)); - reorder_p = std::shared_ptr(new reorder(*reorder_pd)); - - dev_ctx.SetBlob(key_prim, reorder_p); - dev_ctx.SetBlob(key_src_mem, src_memory); - dev_ctx.SetBlob(key_dst_mem, dst_memory); + + std::string out_layout = ctx.Attr("output_format"); + MKLDNNMemoryFormat out_format = + platform::data_format_to_memory_format(out_layout); + mkldnn::primitive_attr attri; + int mask = 0; + attri.set_output_scales(mask, {scale_data}); + + if (with_shift) { + mkldnn::post_ops post_operations; + post_operations.append_sum(); + attri.set_post_ops(post_operations); + uint8_t* output_data = output->mutable_data(ctx.GetPlace()); + // memset casts scale_shift to unsigned char (uint8_t) internally + std::memset(output_data, scale_shift, output->numel()); + } + + auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32, + input->format()); + src_memory = std::make_shared(src_md, engine, + to_void_cast(input_data)); + + std::shared_ptr dst_md; + if (bfloat16) { + platform::SetDstMemoryQuantized( + ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); + } else if (is_negative_input && !with_shift) { + platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, + dst_md, dst_memory, out_format); } else { - src_memory = std::static_pointer_cast( - dev_ctx.GetBlob(key_src_mem)); - src_memory->set_data_handle(to_void_cast(input_data)); - - dst_memory = std::static_pointer_cast( - dev_ctx.GetBlob(key_dst_mem)); - auto place = ctx.GetPlace(); - - if (bfloat16) { - dst_memory->set_data_handle( - output->mutable_data(place)); - } else if (with_shift || !is_negative_input) { - uint8_t* output_data = output->mutable_data(ctx.GetPlace()); - if (with_shift) std::memset(output_data, scale_shift, output->numel()); - dst_memory->set_data_handle(output_data); - } else { - dst_memory->set_data_handle( - output->mutable_data(ctx.GetPlace())); - } + platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, + dst_md, dst_memory, out_format); } + auto reorder_pd = std::shared_ptr( + new reorder::primitive_desc(*src_memory, *dst_memory, attri)); + reorder_p = std::shared_ptr(new reorder(*reorder_pd)); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); { diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 587ad5f37e5..8c81db8c26b 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -11,6 +11,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/device_context.h" #include +#include +#ifdef _WIN32 +#include +#else +#include +#endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" @@ -666,7 +672,7 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) { // of this executor for (auto& s : *p_exec_items_) { for (auto& v : (*s.second)[ptr]) { - (v.first)->erase(v.second); + (v.first)->second.erase(v.second); } s.second->erase(ptr); } @@ -677,12 +683,27 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) { } } -void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const { - p_exec_items_->erase(p_exec_items_->begin()); +std::string MKLDNNDeviceContext::PickLeastUsedShape( + BlobPtr_t sb) const { + auto ancient_one = sb->begin(); + for (auto v = std::next(sb->begin()); v != sb->end(); ++v) { + if (v->second->first < ancient_one->second->first) { + ancient_one = v; + } + } + VLOG(2) << "num_shapes: " << sb->size() + << ", remove all blobs of shape: " << ancient_one->first; + return ancient_one->first; +} + +void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor( + std::string shape_to_be_removed) const { + p_exec_items_->erase(shape_to_be_removed); } -void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t pblob, - KeyBlob::iterator it) const { +void MKLDNNDeviceContext::LinkEntryWithExecutor( + BlobPtr_t> pblob, + KeyBlob::iterator it) const { // Take current input shape from TLS // Take current executor addess from TLS // and for this executor's items add the one defined with arguments @@ -719,7 +740,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, BlobPtr_t data) const { BlobMap* pMap = p_blobmap_.get(); BlobPtr_t sBlob = nullptr; - BlobPtr_t pBlob = nullptr; + BlobPtr_t> pBlob = nullptr; int sid = tls().get_cur_mkldnn_session_id(); @@ -748,22 +769,24 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, sBlob->size() && (sBlob->size() >= static_cast(tls().cur_input_shape_cache_capacity))) { - VLOG(2) << "sid=" << sid - << ", remove all blobs of shape: " << sBlob->begin()->first; - sBlob->erase(sBlob->begin()->first); - RemoveShapeEntriesWithExecutor(); + auto shape_to_be_erased = PickLeastUsedShape(sBlob); + sBlob->erase(shape_to_be_erased); + RemoveShapeEntriesWithExecutor(shape_to_be_erased); } - pBlob = std::make_shared(); + pBlob = std::make_shared>(); + pBlob->first = __rdtsc(); (*sBlob)[tls().cur_input_shape_str] = pBlob; } else { pBlob = key_it->second; + // Update time stamp + pBlob->first = __rdtsc(); } // Find Blob via name - auto blob_it = pBlob->find(name); - if (blob_it == pBlob->end()) { - auto el = - pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data; + auto blob_it = pBlob->second.find(name); + if (blob_it == pBlob->second.end()) { + auto el = pBlob->second.insert( + std::make_pair(name, data)); // (*pBlob)[name] = data; // Register new element in per executor map // to have easily erased when executor terminated LinkEntryWithExecutor(pBlob, el.first); @@ -779,7 +802,7 @@ unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const { unsigned int num_entries = 0; for (auto const& l3 : *p_blobmap_) { for (auto const& l2 : *(l3.second)) { - num_entries += (l2.second)->size(); + num_entries += (l2.second->second).size(); } } return num_entries; @@ -789,7 +812,7 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( const std::string& name) const { BlobMap* pMap = p_blobmap_.get(); BlobPtr_t sBlob = nullptr; - BlobPtr_t pBlob = nullptr; + BlobPtr_t> pBlob = nullptr; int sid = tls().get_cur_mkldnn_session_id(); @@ -813,12 +836,14 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( pBlob = sBlob_it->second; // Find Blob via name - auto key_it = pBlob->find(name); + auto key_it = pBlob->second.find(name); - if (key_it == pBlob->end()) { + if (key_it == pBlob->second.end()) { VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n"; return nullptr; } + // Update timestamp + sBlob_it->second->first = __rdtsc(); // TODO(windows) VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n"; // lock will be automatically released when out of scope diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 13a1040dd19..ee6bbbf2377 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -757,18 +757,20 @@ class MKLDNNDeviceContext : public CPUDeviceContext { // Following three maps are used to cache MKLDNN primitives. // There relations are: // - BlobMap = Map - // - ShapeBlob = Map + // - ShapeBlob = Map> // - KeyBlob = Map using KeyBlob = umap_key_string_t; - using ShapeBlob = umap_key_string_t; + using ShapeBlob = umap_key_string_t>; using BlobMap = umap_value_smart_t; // Auxillary two-level structure (shape, executor) to easier control // clearing cache objects related to specific executor using ExecKey = void*; - using ExecMapCacheIterPair = std::pair, KeyBlob::iterator>; + using ExecMapCacheIterPair = + std::pair>, + KeyBlob::iterator>; using ExecMap = std::unordered_map>; using ExecShape = std::unordered_map>; @@ -779,8 +781,11 @@ class MKLDNNDeviceContext : public CPUDeviceContext { const mkldnn::engine& GetEngine() const { return tls().get_engine(); } // Register object to currently used executor's map - void LinkEntryWithExecutor(BlobPtr_t, KeyBlob::iterator) const; - void RemoveShapeEntriesWithExecutor(void) const; + void LinkEntryWithExecutor( + BlobPtr_t> pblob, + KeyBlob::iterator it) const; + void RemoveShapeEntriesWithExecutor(std::string) const; + std::string PickLeastUsedShape(BlobPtr_t sb) const; // Remove all entries from the blob map void ResetBlobMap(void* ptr); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 084b47bb3c7..5d725307e59 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -500,18 +500,9 @@ class MKLDNNHandlerT { } void AcquireReorder(const std::shared_ptr& user_memory_p, - const std::shared_ptr& target_memory_p, - const std::string& suffix) { - const auto key_reorder_p = key_ + suffix + "reorder_p"; - - auto reorder_p = std::static_pointer_cast( - dev_ctx_.GetBlob(key_reorder_p)); - - if (reorder_p == nullptr) { - reorder_p = - std::make_shared(*user_memory_p, *target_memory_p); - dev_ctx_.SetBlob(key_reorder_p, reorder_p); - } + const std::shared_ptr& target_memory_p) { + auto reorder_p = + std::make_shared(*user_memory_p, *target_memory_p); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); @@ -578,6 +569,8 @@ class MKLDNNHandlerT { std::static_pointer_cast(dev_ctx_.GetBlob(user_key)); user_memory_p->set_data_handle(ptr); + // TODO(jczaja): Here we detect if reorder is cached it means it is needed + // need to change this to get rid of keys auto reorder_p = std::static_pointer_cast( dev_ctx_.GetBlob(key_reorder_p)); if (reorder_p != nullptr) { -- GitLab