From 5d18967b66af832435856c76db174faf8919fa26 Mon Sep 17 00:00:00 2001 From: lidanqing Date: Thu, 14 Oct 2021 15:24:34 +0800 Subject: [PATCH] Revert "Implemented LRU based cache clearing (#36290)" (#36426) This reverts commit bf748f245eb74ffc86e44853fa9ebad7c858b015. --- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 49 ++++---- .../mkldnn/conv_transpose_mkldnn_op.cc | 33 +++--- .../operators/mkldnn/quantize_mkldnn_op.cc | 105 ++++++++++++------ paddle/fluid/platform/device_context.cc | 63 ++++------- paddle/fluid/platform/device_context.h | 15 +-- paddle/fluid/platform/mkldnn_reuse.h | 17 ++- 6 files changed, 146 insertions(+), 136 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 84c989f64e..cce835e6bc 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -78,8 +78,7 @@ class ConvMKLDNNHandlerT mkldnn::convolution_backward_weights>( dev_ctx, mkldnn_engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), - unique_name)), - is_test_(ctx.Attr("is_test")) { + unique_name)) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( input->layout(), framework::DataLayout::kMKLDNN, @@ -160,6 +159,7 @@ class ConvMKLDNNHandlerT framework::slice_ddim(filter_dims, 2, filter_dims.size()); const auto ksize = framework::vectorize(filter_data_dims); + const bool is_test = ctx.Attr("is_test"); auto strides_temp = ctx.Attr>("strides"); std::vector strides(begin(strides_temp), end(strides_temp)); @@ -214,8 +214,9 @@ class ConvMKLDNNHandlerT const auto dst_md = platform::MKLDNNMemDesc( dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - const auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference - : mkldnn::prop_kind::forward_training; + const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference + : mkldnn::prop_kind::forward_training; + float sum_scale = 1.0f; std::vector output_shift_scale; if (platform::is_int8()) @@ -260,8 +261,7 @@ class ConvMKLDNNHandlerT mkldnn::convolution_backward_weights>( dev_ctx, dev_ctx.GetEngine(), cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(in->dims()), - unique_name)), - is_test_(false) { + unique_name)) { if (!this->isBwdCached()) { PADDLE_ENFORCE_EQ( in->layout(), framework::DataLayout::kMKLDNN, @@ -291,7 +291,7 @@ class ConvMKLDNNHandlerT "Wrong format set for output_grad tensor")); PADDLE_ENFORCE_EQ( - is_test_, false, + ctx.Attr("is_test"), false, platform::errors::InvalidArgument( "is_test attribute should be set to False in training phase.")); @@ -557,14 +557,13 @@ class ConvMKLDNNHandlerT framework::vectorize(in_mem->dims()), platform::MKLDNNGetDataType(), in_mem->format()); return this->AcquireMemoryWithReorder( - user_mem_md, mem_md, platform::to_void_cast(in_mem_data), key_mem, - is_test_); + user_mem_md, mem_md, platform::to_void_cast(in_mem_data), key_mem); } else { const std::string target_key_suffix{key_mem_target}; const auto target_mem_p = this->AcquireMemory(target_key_suffix); user_mem_p->set_data_handle(platform::to_void_cast(in_mem_data)); if (user_mem_p != target_mem_p) { - this->AcquireReorder(user_mem_p, target_mem_p); + this->AcquireReorder(user_mem_p, target_mem_p, key_mem); } return target_mem_p; } @@ -572,11 +571,12 @@ class ConvMKLDNNHandlerT std::shared_ptr AcquireWeightsMemoryWithReorder( const framework::Tensor* filter, const int groups, const bool is_conv3d, - const std::vector& scale_data = {1.0f}, int mask = 0) { + const bool is_test, const std::vector& scale_data = {1.0f}, + int mask = 0) { // This is workaround to make execution faster, delete // if statement after including md inside Tensor auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); - if (is_test_ && weights_mem_p) { + if (is_test && weights_mem_p) { return weights_mem_p; } else { const K* filter_data = filter->data(); @@ -589,16 +589,16 @@ class ConvMKLDNNHandlerT return this->AcquireMemoryWithReorder( user_src_md, this->fwd_pd_->weights_desc(), - platform::to_void_cast(filter_data), "@weights_mem_p", is_test_, - {}, scale_data, mask); + platform::to_void_cast(filter_data), "@weights_mem_p", is_test, {}, + scale_data, mask); } } std::shared_ptr AcquireBiasMemoryWithReorder( - const framework::Tensor* bias, + const framework::Tensor* bias, const bool is_test, const std::vector& scale_data = {1.0f}, int mask = 0) { auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); - if (is_test_ && bias_mem_p) { + if (is_test && bias_mem_p) { return bias_mem_p; } else { const K* bias_data = bias->data(); @@ -608,7 +608,7 @@ class ConvMKLDNNHandlerT return this->AcquireMemoryWithReorder( user_bias_md, this->fwd_pd_->bias_desc(), - platform::to_void_cast(bias_data), "@bias_mem_p", is_test_, {}, + platform::to_void_cast(bias_data), "@bias_mem_p", is_test, {}, scale_data, mask); } } @@ -641,7 +641,7 @@ class ConvMKLDNNHandlerT platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) { auto residual_memory_p = this->AcquireResidualMemory(residual_param); dst_memory_p = this->template AcquireDstMemory(output); - this->AcquireReorder(residual_memory_p, dst_memory_p); + this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst"); } else { // Changing ShareDataWith to TensorCopy results in performance drop // on ResNet architectures @@ -651,9 +651,6 @@ class ConvMKLDNNHandlerT } return dst_memory_p; } - - private: - const bool is_test_; }; } // anonymous namespace @@ -698,6 +695,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); + const bool is_test = ctx.Attr("is_test"); const bool is_conv3d = ctx.Attr>("strides").size() == 3U; const bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); @@ -714,7 +712,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, ctx.Attr("groups"), is_conv3d); + filter, ctx.Attr("groups"), is_conv3d, is_test); std::shared_ptr dst_memory_p; if (fuse_residual_conn) { @@ -733,7 +731,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { {MKLDNN_ARG_DST, *dst_memory_p}}; if (bias) { - auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias); + auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); } @@ -785,10 +783,11 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { ctx.Attr>("Scale_weights"); const bool is_multi_channel = scale_weights_data.size() > 1; const int& groups = ctx.Attr("groups"); + const bool& is_test = ctx.Attr("is_test"); int mask_reorder = is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0; auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, groups, false, scale_weights_data, mask_reorder); + filter, groups, false, is_test, scale_weights_data, mask_reorder); std::shared_ptr dst_memory_p; if (fuse_residual_conn) { @@ -823,7 +822,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { handler.get_int8_bias_scales(ctx); auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( - bias, scale_bias_data, mask_reorder); + bias, is_test, scale_bias_data, mask_reorder); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); } diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index 4c374d72c0..8d43e9f0dc 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -51,10 +51,10 @@ class ConvTransposeMKLDNNHandlerT : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), - unique_name)), - is_test_(ctx.Attr("is_test")) { + unique_name)) { if (!this->isCached()) { - PADDLE_ENFORCE_EQ(is_test_, true, + const bool is_test = ctx.Attr("is_test"); + PADDLE_ENFORCE_EQ(is_test, true, platform::errors::InvalidArgument( "ConvTransposeMKLDNN works only for inference. " "The attribute \'is_test\' value should be set to " @@ -169,8 +169,8 @@ class ConvTransposeMKLDNNHandlerT const mkldnn::primitive_attr conv_trans_attr = CreatePostOps(fuse_activation, fuse_alpha, fuse_beta); - auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference - : mkldnn::prop_kind::forward_training; + auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference + : mkldnn::prop_kind::forward_training; if (bias) { std::vector bias_tz = framework::vectorize(bias->dims()); const auto bias_md = @@ -231,18 +231,18 @@ class ConvTransposeMKLDNNHandlerT const auto target_src_mem_p = this->AcquireMemory(target_key_suffix); user_src_mem_p->set_data_handle(platform::to_void_cast(input_data)); if (user_src_mem_p != target_src_mem_p) { - this->AcquireReorder(user_src_mem_p, target_src_mem_p); + this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p"); } return target_src_mem_p; } } std::shared_ptr AcquireWeightsMemoryWithReorder( - const framework::Tensor* filter, const int& groups) { + const framework::Tensor* filter, const int& groups, const bool& is_test) { // This is workaround to make execution faster, delete // if statement after including md inside Tensor auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); - if (is_test_ && weights_mem_p) { + if (is_test && weights_mem_p) { return weights_mem_p; } else { const K* filter_data = filter->data(); @@ -277,15 +277,15 @@ class ConvTransposeMKLDNNHandlerT return this->template AcquireMemoryWithReorder( user_src_md, this->fwd_pd_->weights_desc(), - platform::to_void_cast(filter_data), "@weights_mem_p", is_test_, + platform::to_void_cast(filter_data), "@weights_mem_p", is_test, iohw2oihw_reorder); } } std::shared_ptr AcquireBiasMemoryWithReorder( - const framework::Tensor* bias) { + const framework::Tensor* bias, const bool& is_test) { auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); - if (is_test_ && bias_mem_p) { + if (is_test && bias_mem_p) { return bias_mem_p; } else { const K* bias_data = bias->data(); @@ -294,12 +294,9 @@ class ConvTransposeMKLDNNHandlerT MKLDNNMemoryFormat::x); return this->AcquireMemoryWithReorder( user_bias_md, this->fwd_pd_->bias_desc(), - platform::to_void_cast(bias_data), "@bias_mem_p", is_test_); + platform::to_void_cast(bias_data), "@bias_mem_p", is_test); } } - - private: - const bool is_test_; }; template @@ -328,6 +325,8 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel { ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); + const bool is_test = ctx.Attr("is_test"); + const auto* input = ctx.Input("Input"); const auto* filter = ctx.Input("Filter"); const auto* bias = @@ -341,7 +340,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel { output, unique_name); auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, ctx.Attr("groups")); + filter, ctx.Attr("groups"), is_test); std::shared_ptr dst_memory_p = handler.template AcquireDstMemory(output); @@ -353,7 +352,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel { {MKLDNN_ARG_DST, *dst_memory_p}}; if (bias) { - auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias); + auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); } auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc index 815af4eaaf..819c0d1550 100644 --- a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc @@ -64,46 +64,81 @@ class QuantOpKernel : public framework::OpKernel { bool is_negative_input = ctx.Attr("is_negative_input"); bool bfloat16 = ctx.Attr("bfloat16"); - // TODO(jczaja): Refactor with Acquire API + std::string key = + platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift, + is_negative_input, ctx.OutputName("Output")); + key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + + const std::string key_prim = key + "@r"; + const std::string key_src_mem = key + "@s"; + const std::string key_dst_mem = key + "@d"; + std::shared_ptr src_memory; std::shared_ptr dst_memory; std::shared_ptr reorder_p; - - std::string out_layout = ctx.Attr("output_format"); - MKLDNNMemoryFormat out_format = - platform::data_format_to_memory_format(out_layout); - mkldnn::primitive_attr attri; - int mask = 0; - attri.set_output_scales(mask, {scale_data}); - - if (with_shift) { - mkldnn::post_ops post_operations; - post_operations.append_sum(); - attri.set_post_ops(post_operations); - uint8_t* output_data = output->mutable_data(ctx.GetPlace()); - // memset casts scale_shift to unsigned char (uint8_t) internally - std::memset(output_data, scale_shift, output->numel()); - } - - auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32, - input->format()); - src_memory = std::make_shared(src_md, engine, - to_void_cast(input_data)); - - std::shared_ptr dst_md; - if (bfloat16) { - platform::SetDstMemoryQuantized( - ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); - } else if (is_negative_input && !with_shift) { - platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, - dst_md, dst_memory, out_format); + reorder_p = std::static_pointer_cast(dev_ctx.GetBlob(key_prim)); + + if (reorder_p == nullptr) { + std::string out_layout = ctx.Attr("output_format"); + MKLDNNMemoryFormat out_format = + platform::data_format_to_memory_format(out_layout); + mkldnn::primitive_attr attri; + int mask = 0; + attri.set_output_scales(mask, {scale_data}); + + if (with_shift) { + mkldnn::post_ops post_operations; + post_operations.append_sum(); + attri.set_post_ops(post_operations); + uint8_t* output_data = output->mutable_data(ctx.GetPlace()); + // memset casts scale_shift to unsigned char (uint8_t) internally + std::memset(output_data, scale_shift, output->numel()); + } + + auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32, + input->format()); + src_memory = std::make_shared( + src_md, engine, to_void_cast(input_data)); + + std::shared_ptr dst_md; + if (bfloat16) { + platform::SetDstMemoryQuantized( + ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); + } else if (is_negative_input && !with_shift) { + platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, + dst_md, dst_memory, out_format); + } else { + platform::SetDstMemoryQuantized( + ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); + } + auto reorder_pd = std::shared_ptr( + new reorder::primitive_desc(*src_memory, *dst_memory, attri)); + reorder_p = std::shared_ptr(new reorder(*reorder_pd)); + + dev_ctx.SetBlob(key_prim, reorder_p); + dev_ctx.SetBlob(key_src_mem, src_memory); + dev_ctx.SetBlob(key_dst_mem, dst_memory); } else { - platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, - dst_md, dst_memory, out_format); + src_memory = std::static_pointer_cast( + dev_ctx.GetBlob(key_src_mem)); + src_memory->set_data_handle(to_void_cast(input_data)); + + dst_memory = std::static_pointer_cast( + dev_ctx.GetBlob(key_dst_mem)); + auto place = ctx.GetPlace(); + + if (bfloat16) { + dst_memory->set_data_handle( + output->mutable_data(place)); + } else if (with_shift || !is_negative_input) { + uint8_t* output_data = output->mutable_data(ctx.GetPlace()); + if (with_shift) std::memset(output_data, scale_shift, output->numel()); + dst_memory->set_data_handle(output_data); + } else { + dst_memory->set_data_handle( + output->mutable_data(ctx.GetPlace())); + } } - auto reorder_pd = std::shared_ptr( - new reorder::primitive_desc(*src_memory, *dst_memory, attri)); - reorder_p = std::shared_ptr(new reorder(*reorder_pd)); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); { diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 8c81db8c26..587ad5f37e 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -11,12 +11,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/device_context.h" #include -#include -#ifdef _WIN32 -#include -#else -#include -#endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" @@ -672,7 +666,7 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) { // of this executor for (auto& s : *p_exec_items_) { for (auto& v : (*s.second)[ptr]) { - (v.first)->second.erase(v.second); + (v.first)->erase(v.second); } s.second->erase(ptr); } @@ -683,27 +677,12 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) { } } -std::string MKLDNNDeviceContext::PickLeastUsedShape( - BlobPtr_t sb) const { - auto ancient_one = sb->begin(); - for (auto v = std::next(sb->begin()); v != sb->end(); ++v) { - if (v->second->first < ancient_one->second->first) { - ancient_one = v; - } - } - VLOG(2) << "num_shapes: " << sb->size() - << ", remove all blobs of shape: " << ancient_one->first; - return ancient_one->first; -} - -void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor( - std::string shape_to_be_removed) const { - p_exec_items_->erase(shape_to_be_removed); +void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const { + p_exec_items_->erase(p_exec_items_->begin()); } -void MKLDNNDeviceContext::LinkEntryWithExecutor( - BlobPtr_t> pblob, - KeyBlob::iterator it) const { +void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t pblob, + KeyBlob::iterator it) const { // Take current input shape from TLS // Take current executor addess from TLS // and for this executor's items add the one defined with arguments @@ -740,7 +719,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, BlobPtr_t data) const { BlobMap* pMap = p_blobmap_.get(); BlobPtr_t sBlob = nullptr; - BlobPtr_t> pBlob = nullptr; + BlobPtr_t pBlob = nullptr; int sid = tls().get_cur_mkldnn_session_id(); @@ -769,24 +748,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, sBlob->size() && (sBlob->size() >= static_cast(tls().cur_input_shape_cache_capacity))) { - auto shape_to_be_erased = PickLeastUsedShape(sBlob); - sBlob->erase(shape_to_be_erased); - RemoveShapeEntriesWithExecutor(shape_to_be_erased); + VLOG(2) << "sid=" << sid + << ", remove all blobs of shape: " << sBlob->begin()->first; + sBlob->erase(sBlob->begin()->first); + RemoveShapeEntriesWithExecutor(); } - pBlob = std::make_shared>(); - pBlob->first = __rdtsc(); + pBlob = std::make_shared(); (*sBlob)[tls().cur_input_shape_str] = pBlob; } else { pBlob = key_it->second; - // Update time stamp - pBlob->first = __rdtsc(); } // Find Blob via name - auto blob_it = pBlob->second.find(name); - if (blob_it == pBlob->second.end()) { - auto el = pBlob->second.insert( - std::make_pair(name, data)); // (*pBlob)[name] = data; + auto blob_it = pBlob->find(name); + if (blob_it == pBlob->end()) { + auto el = + pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data; // Register new element in per executor map // to have easily erased when executor terminated LinkEntryWithExecutor(pBlob, el.first); @@ -802,7 +779,7 @@ unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const { unsigned int num_entries = 0; for (auto const& l3 : *p_blobmap_) { for (auto const& l2 : *(l3.second)) { - num_entries += (l2.second->second).size(); + num_entries += (l2.second)->size(); } } return num_entries; @@ -812,7 +789,7 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( const std::string& name) const { BlobMap* pMap = p_blobmap_.get(); BlobPtr_t sBlob = nullptr; - BlobPtr_t> pBlob = nullptr; + BlobPtr_t pBlob = nullptr; int sid = tls().get_cur_mkldnn_session_id(); @@ -836,14 +813,12 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( pBlob = sBlob_it->second; // Find Blob via name - auto key_it = pBlob->second.find(name); + auto key_it = pBlob->find(name); - if (key_it == pBlob->second.end()) { + if (key_it == pBlob->end()) { VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n"; return nullptr; } - // Update timestamp - sBlob_it->second->first = __rdtsc(); // TODO(windows) VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n"; // lock will be automatically released when out of scope diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index ee6bbbf237..13a1040dd1 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -757,20 +757,18 @@ class MKLDNNDeviceContext : public CPUDeviceContext { // Following three maps are used to cache MKLDNN primitives. // There relations are: // - BlobMap = Map - // - ShapeBlob = Map> + // - ShapeBlob = Map // - KeyBlob = Map using KeyBlob = umap_key_string_t; - using ShapeBlob = umap_key_string_t>; + using ShapeBlob = umap_key_string_t; using BlobMap = umap_value_smart_t; // Auxillary two-level structure (shape, executor) to easier control // clearing cache objects related to specific executor using ExecKey = void*; - using ExecMapCacheIterPair = - std::pair>, - KeyBlob::iterator>; + using ExecMapCacheIterPair = std::pair, KeyBlob::iterator>; using ExecMap = std::unordered_map>; using ExecShape = std::unordered_map>; @@ -781,11 +779,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext { const mkldnn::engine& GetEngine() const { return tls().get_engine(); } // Register object to currently used executor's map - void LinkEntryWithExecutor( - BlobPtr_t> pblob, - KeyBlob::iterator it) const; - void RemoveShapeEntriesWithExecutor(std::string) const; - std::string PickLeastUsedShape(BlobPtr_t sb) const; + void LinkEntryWithExecutor(BlobPtr_t, KeyBlob::iterator) const; + void RemoveShapeEntriesWithExecutor(void) const; // Remove all entries from the blob map void ResetBlobMap(void* ptr); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 5d725307e5..084b47bb3c 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -500,9 +500,18 @@ class MKLDNNHandlerT { } void AcquireReorder(const std::shared_ptr& user_memory_p, - const std::shared_ptr& target_memory_p) { - auto reorder_p = - std::make_shared(*user_memory_p, *target_memory_p); + const std::shared_ptr& target_memory_p, + const std::string& suffix) { + const auto key_reorder_p = key_ + suffix + "reorder_p"; + + auto reorder_p = std::static_pointer_cast( + dev_ctx_.GetBlob(key_reorder_p)); + + if (reorder_p == nullptr) { + reorder_p = + std::make_shared(*user_memory_p, *target_memory_p); + dev_ctx_.SetBlob(key_reorder_p, reorder_p); + } auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); @@ -569,8 +578,6 @@ class MKLDNNHandlerT { std::static_pointer_cast(dev_ctx_.GetBlob(user_key)); user_memory_p->set_data_handle(ptr); - // TODO(jczaja): Here we detect if reorder is cached it means it is needed - // need to change this to get rid of keys auto reorder_p = std::static_pointer_cast( dev_ctx_.GetBlob(key_reorder_p)); if (reorder_p != nullptr) { -- GitLab